train
# After instantiating a model, calling .train(True) on it puts that model into training mode
def train(self: T, mode: bool = True) -> T:
    r"""Sets the module in training mode.

    This has any effect only on certain modules. See documentations of
    particular modules for details of their behaviors in training/evaluation
    mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    Args:
        mode (bool): whether to set training mode (``True``) or evaluation
            mode (``False``). Default: ``True``.

    Returns:
        Module: self
    """
    # First check that mode is a bool
    if not isinstance(mode, bool):
        raise ValueError("training mode is expected to be boolean")
    self.training = mode  # if True is passed in, self.training becomes True
    for module in self.children():
        module.train(mode)  # all submodules are recursively set to the same mode
    return self
eval()
def eval(self: T) -> T:
    r"""Sets the module in evaluation mode.

    This has any effect only on certain modules. See documentations of
    particular modules for details of their behaviors in training/evaluation
    mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.

    See :ref:`locally-disable-grad-doc` for a comparison between
    `.eval()` and several similar mechanisms that may be confused with it.

    Returns:
        Module: self
    """
    return self.train(False)
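A minimal sketch of what these two calls do (the network here is illustrative): they flip the `training` flag on the module and, recursively, on every child, which is what makes layers such as Dropout behave differently between training and evaluation:

import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

net.train()
print(net.training, net[1].training)   # True True  -> Dropout is active

net.eval()
print(net.training, net[1].training)   # False False -> Dropout is a no-op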
requires_grad_
# Module-level switch for whether autograd should record operations on its parameters
def requires_grad_(self: T, requires_grad: bool = True) -> T:
    r"""Change if autograd should record operations on parameters in this
    module.

    This method sets the parameters' :attr:`requires_grad` attributes
    in-place.

    This method is helpful for freezing part of the module for finetuning
    or training parts of a model individually (e.g., GAN training).

    See :ref:`locally-disable-grad-doc` for a comparison between
    `.requires_grad_()` and several similar mechanisms that may be confused with it.

    Args:
        requires_grad (bool): whether autograd should record operations on
            parameters in this module. Default: ``True``.

    Returns:
        Module: self
    """
    # Iterate over all parameters of this module
    for p in self.parameters():
        # Call the in-place Tensor method requires_grad_ on each parameter
        p.requires_grad_(requires_grad)
    return self
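A small sketch of the freezing use case mentioned in the docstring, using an illustrative backbone/head split: calling requires_grad_(False) on one submodule stops autograd from recording its parameters, while the rest of the model keeps training normally.

import torch.nn as nn

backbone = nn.Linear(8, 8)   # illustrative "pretrained" part
head = nn.Linear(8, 2)       # illustrative trainable part

backbone.requires_grad_(False)   # freezes weight and bias of the backbone

print(all(not p.requires_grad for p in backbone.parameters()))  # True
print(all(p.requires_grad for p in head.parameters()))          # True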
zero_grad
Call the optimizer's zero_grad at the start of every training step. PyTorch accumulates gradients into each parameter's .grad: the first backward pass produces the correct gradients, but if zero_grad is not called before the second pass, the gradients become twice as large. Calling it on the optimizer is sufficient.
def zero_grad(self, set_to_none: bool = True) -> None:
    r"""Resets gradients of all model parameters. See similar function
    under :class:`torch.optim.Optimizer` for more context.

    Args:
        set_to_none (bool): instead of setting to zero, set the grads to None.
            See :meth:`torch.optim.Optimizer.zero_grad` for details.
    """
    if getattr(self, '_is_replica', False):
        warnings.warn(
            "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
            "The parameters are copied (in a differentiable manner) from the original module. "
            "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
            "If you need gradients in your forward method, consider using autograd.grad instead.")

    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()
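A tiny sketch of the accumulation behaviour described above (the single-weight layer is just for illustration): backward() adds into .grad, so without clearing it the second pass doubles the gradient, and zero_grad(set_to_none=True) resets it to None.

import torch
import torch.nn as nn

lin = nn.Linear(1, 1, bias=False)
x = torch.ones(1, 1)

lin(x).sum().backward()
g1 = lin.weight.grad.clone()

lin(x).sum().backward()                           # no zero_grad in between
print(torch.allclose(lin.weight.grad, 2 * g1))    # True: the gradients accumulated

lin.zero_grad(set_to_none=True)
print(lin.weight.grad)                            # None: gradients were cleared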
__repr__
A magic method that renders the model as a string.
def __repr__(self):
    # We treat the extra repr like the sub-module, one item per line
    extra_lines = []
    extra_repr = self.extra_repr()
    # empty string will be split into list ['']
    if extra_repr:
        extra_lines = extra_repr.split('\n')
    child_lines = []
    for key, module in self._modules.items():
        mod_str = repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append('(' + key + '): ' + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + '('
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(lines) + '\n'

    main_str += ')'
    return main_str
print(str(test_net))
'''
Net(
(linear1): Linear(in_features=2, out_features=3, bias=True)
(linear2): Linear(in_features=3, out_features=4, bias=True)
(batch_norm): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
'''
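The one-liner part of this printout comes from extra_repr(), which __repr__ calls above. A small sketch of overriding it in a custom module (MyScale is a hypothetical name used only for illustration):

import torch.nn as nn

class MyScale(nn.Module):
    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, x):
        return x * self.factor

    def extra_repr(self):
        # this string becomes the content between the parentheses in repr()
        return f'factor={self.factor}'

print(MyScale(2.0))   # MyScale(factor=2.0)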
__dir__
def __dir__(self):
    module_attrs = dir(self.__class__)          # attributes defined on the class itself
    attrs = list(self.__dict__.keys())          # keys of the instance __dict__
    parameters = list(self._parameters.keys())  # parameter names
    modules = list(self._modules.keys())        # submodule names
    buffers = list(self._buffers.keys())        # buffer names
    keys = module_attrs + attrs + parameters + modules + buffers

    # Eliminate attrs that are not legal Python variable names
    keys = [key for key in keys if not key[0].isdigit()]

    return sorted(keys)
print(dir(test_net))
'''
['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_replicate_for_data_parallel', '_save_to_state_dict', '_slow_forward', '_state_dict_hooks', '_state_dict_pre_hooks', '_version', '_wrapped_call_impl', 'add_module', 'apply', 'batch_norm', 'bfloat16', 'buffers', 'call_super_init', 'children', 'compile', 'cpu', 'cuda', 'double', 'dump_patches', 'eval', 'extra_repr', 'float', 'forward', 'get_buffer', 'get_extra_state', 'get_parameter', 'get_submodule', 'half', 'ipu', 'linear1', 'linear2', 'load_state_dict', 'modules', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'parameters', 'register_backward_hook', 'register_buffer', 'register_forward_hook', 'register_forward_pre_hook', 'register_full_backward_hook', 'register_full_backward_pre_hook', 'register_load_state_dict_post_hook', 'register_module', 'register_parameter', 'register_state_dict_pre_hook', 'requires_grad_', 'set_extra_state', 'share_memory', 'state_dict', 'to', 'to_empty', 'train', 'training', 'type', 'xpu', 'zero_grad']
'''
Sequential
def __init__(self, *args):
    super().__init__()
    if len(args) == 1 and isinstance(args[0], OrderedDict):
        # Iterate over the dict and add each module to this Sequential under the given key
        for key, module in args[0].items():
            self.add_module(key, module)
    else:
        # Module instances passed directly are also added, but their names are numeric strings starting from 0
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)
Example
s = nn.Sequential(
    nn.Linear(2, 3),
    nn.Linear(3, 4)
)
print(s)
'''
Sequential(
(0): Linear(in_features=2, out_features=3, bias=True)
(1): Linear(in_features=3, out_features=4, bias=True)
)
'''
print(s._modules)
'''
OrderedDict([('0', Linear(in_features=2, out_features=3, bias=True)), ('1', Linear(in_features=3, out_features=4, bias=True))])
'''
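The OrderedDict branch of __init__ shown above gives each submodule a chosen name instead of the automatic '0', '1', ... indices; a quick sketch:

from collections import OrderedDict
import torch.nn as nn

s = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(2, 3)),
    ('fc2', nn.Linear(3, 4)),
]))
print(s)
'''
Sequential(
  (fc1): Linear(in_features=2, out_features=3, bias=True)
  (fc2): Linear(in_features=3, out_features=4, bias=True)
)
'''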
forward
The input is passed through each module in turn:
def forward(self, input):
    for module in self:
        input = module(input)
    return input
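A short usage sketch: calling the container invokes this forward, so the input flows through Linear(2, 3) and then Linear(3, 4):

import torch
import torch.nn as nn

s = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 4))
x = torch.randn(5, 2)
y = s(x)
print(y.shape)   # torch.Size([5, 4])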
ModuleList
Holds many sub-modules in a list.
Example::

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

        def forward(self, x):
            # ModuleList can act as an iterable, or be indexed using ints
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
            return x
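A small sketch of why ModuleList is used instead of a plain Python list (WithList and WithPlainList are illustrative names): only the ModuleList entries are registered as submodules, so only their parameters are visible to parameters(), state_dict(), .to(), and so on.

import torch.nn as nn

class WithList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])

class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(10, 10) for _ in range(3)]   # not registered as submodules

print(len(list(WithList().parameters())))       # 6 (weight + bias for each of the 3 layers)
print(len(list(WithPlainList().parameters())))  # 0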