train

# After instantiating a model, call .train(True) on it to put the model (and its sub-modules) into training mode
def train(self: T, mode: bool = True) -> T:
    r"""Sets the module in training mode.

    This has any effect only on certain modules. See documentations of
    particular modules for details of their behaviors in training/evaluation
    mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    Args:
        mode (bool): whether to set training mode (``True``) or evaluation
                     mode (``False``). Default: ``True``.

    Returns:
        Module: self
    """
    # First check that mode is a bool
    if not isinstance(mode, bool):
        raise ValueError("training mode is expected to be boolean")
    self.training = mode  # If True is passed, self.training becomes True, and all child modules are recursively set to training mode as well
    for module in self.children():
        module.train(mode)
    return self
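
A quick check of the recursive behavior (a minimal sketch, not from the original post): calling train(True)/train(False) on a container also flips the training flag on every child module.

import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 3), nn.Dropout(p=0.5))

model.train()                             # training mode
print(model.training, model[1].training)  # True True -- propagated to children

model.train(False)                        # same as model.eval()
print(model.training, model[1].training)  # False False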

eval()

def eval(self: T) -> T:
    r"""Sets the module in evaluation mode.

    This has any effect only on certain modules. See documentations of
    particular modules for details of their behaviors in training/evaluation
    mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.

    See :ref:`locally-disable-grad-doc` for a comparison between
    `.eval()` and several similar mechanisms that may be confused with it.

    Returns:
        Module: self
    """
    return self.train(False)
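
Dropout is one of the modules affected by the mode switch; a minimal sketch (an assumed example, not from the original post):

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()      # training mode: roughly half the elements are zeroed, the rest scaled by 1/(1-p)
print(drop(x))

drop.eval()       # evaluation mode: Dropout acts as the identity
print(drop(x))    # all ones, unchanged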

requires_grad_

# Module-level switch for whether autograd should record operations on this module's parameters
def requires_grad_(self: T, requires_grad: bool = True) -> T:
    r"""Change if autograd should record operations on parameters in this
    module.

    This method sets the parameters' :attr:`requires_grad` attributes
    in-place.

    This method is helpful for freezing part of the module for finetuning
    or training parts of a model individually (e.g., GAN training).

    See :ref:`locally-disable-grad-doc` for a comparison between
    `.requires_grad_()` and several similar mechanisms that may be confused with it.

    Args:
        requires_grad (bool): whether autograd should record operations on
                              parameters in this module. Default: ``True``.

    Returns:
        Module: self
    """
    # Iterate over every parameter of this module
    for p in self.parameters():
        # Set each parameter's requires_grad flag via the Tensor method requires_grad_
        p.requires_grad_(requires_grad)
    return self
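
A typical use, as the docstring notes, is freezing part of a model for finetuning; a minimal sketch:

import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 4))

# Freeze the first layer; only the second layer's parameters will receive gradients
model[0].requires_grad_(False)
print([p.requires_grad for p in model.parameters()])  # [False, False, True, True]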

zero_grad

Call the optimizer's zero_grad at the start of every training step. PyTorch accumulates gradients on each parameter: the first backward pass computes the correct gradients, but if zero_grad is not called before the second pass, the new gradients are added on top and the result is doubled (see the short demonstration after the source below). Calling it on the optimizer is sufficient.

def zero_grad(self, set_to_none: bool = True) -> None:
    r"""Resets gradients of all model parameters. See similar function
    under :class:`torch.optim.Optimizer` for more context.

    Args:
        set_to_none (bool): instead of setting to zero, set the grads to None.
            See :meth:`torch.optim.Optimizer.zero_grad` for details.
    """
    if getattr(self, '_is_replica', False):
        warnings.warn(
            "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
            "The parameters are copied (in a differentiable manner) from the original module. "
            "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
            "If you need gradients in your forward method, consider using autograd.grad instead.")

    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()
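
A short demonstration of the accumulation behavior described above (a sketch with made-up shapes); zero_grad is called on the module here, but calling it on the optimizer works the same way:

import torch
import torch.nn as nn

model = nn.Linear(2, 1)
x = torch.randn(4, 2)

model(x).sum().backward()
g1 = model.weight.grad.clone()

# Second backward pass without zeroing: gradients are added on top of the first
model(x).sum().backward()
print(torch.allclose(model.weight.grad, 2 * g1))  # True -- doubled

# Reset, then the next backward pass produces the correct gradients again
model.zero_grad()
model(x).sum().backward()
print(torch.allclose(model.weight.grad, g1))      # True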

__repr__

A magic method that produces the string representation of the model.

def __repr__(self):
    # We treat the extra repr like the sub-module, one item per line
    extra_lines = []
    extra_repr = self.extra_repr()
    # empty string will be split into list ['']
    if extra_repr:
        extra_lines = extra_repr.split('\n')
    child_lines = []
    for key, module in self._modules.items():
        mod_str = repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append('(' + key + '): ' + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + '('
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(lines) + '\n'

    main_str += ')'
    return main_str
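
The Net behind test_net is not defined in this excerpt; a minimal sketch consistent with the output printed below (forward omitted, since only the string representation is exercised here):

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 3)
        self.linear2 = nn.Linear(3, 4)
        self.batch_norm = nn.BatchNorm2d(4)

test_net = Net()
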
print(str(test_net))
'''
Net(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (linear2): Linear(in_features=3, out_features=4, bias=True)
  (batch_norm): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
'''

__dir__

def __dir__(self):
    module_attrs = dir(self.__class__)  # attributes defined on the class itself
    attrs = list(self.__dict__.keys())  # keys in the instance __dict__
    parameters = list(self._parameters.keys())  # registered parameter names
    modules = list(self._modules.keys())  # registered sub-module names
    buffers = list(self._buffers.keys())  # registered buffer names
    keys = module_attrs + attrs + parameters + modules + buffers

    # Eliminate attrs that are not legal Python variable names
    keys = [key for key in keys if not key[0].isdigit()]

    return sorted(keys)
print(dir(test_net))

'''
['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_replicate_for_data_parallel', '_save_to_state_dict', '_slow_forward', '_state_dict_hooks', '_state_dict_pre_hooks', '_version', '_wrapped_call_impl', 'add_module', 'apply', 'batch_norm', 'bfloat16', 'buffers', 'call_super_init', 'children', 'compile', 'cpu', 'cuda', 'double', 'dump_patches', 'eval', 'extra_repr', 'float', 'forward', 'get_buffer', 'get_extra_state', 'get_parameter', 'get_submodule', 'half', 'ipu', 'linear1', 'linear2', 'load_state_dict', 'modules', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'parameters', 'register_backward_hook', 'register_buffer', 'register_forward_hook', 'register_forward_pre_hook', 'register_full_backward_hook', 'register_full_backward_pre_hook', 'register_load_state_dict_post_hook', 'register_module', 'register_parameter', 'register_state_dict_pre_hook', 'requires_grad_', 'set_extra_state', 'share_memory', 'state_dict', 'to', 'to_empty', 'train', 'training', 'type', 'xpu', 'zero_grad']

'''

Sequential

def __init__(self, *args):
    super().__init__()
    if len(args) == 1 and isinstance(args[0], OrderedDict):
        # Iterate over the OrderedDict and add each module to this Sequential under its given key
        for key, module in args[0].items():
            self.add_module(key, module)
    else:
        # Modules passed positionally are also added, but are named by their index as a string, starting from 0
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

Example

s = nn.Sequential(
    nn.Linear(2, 3),
    nn.Linear(3, 4)
)

print(s)
'''
Sequential(
  (0): Linear(in_features=2, out_features=3, bias=True)
  (1): Linear(in_features=3, out_features=4, bias=True)
)
'''
print(s._modules)
'''
OrderedDict([('0', Linear(in_features=2, out_features=3, bias=True)), ('1', Linear(in_features=3, out_features=4, bias=True))])

'''
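
When a single OrderedDict is passed instead, the keys become the sub-module names (a minimal sketch):

from collections import OrderedDict
import torch.nn as nn

s2 = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(2, 3)),
    ('fc2', nn.Linear(3, 4)),
]))
print(s2)
'''
Sequential(
  (fc1): Linear(in_features=2, out_features=3, bias=True)
  (fc2): Linear(in_features=3, out_features=4, bias=True)
)
'''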

forward

The input is passed through each sub-module in order.

def forward(self, input):
    for module in self:
        input = module(input)
    return input
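
Using the s defined above, the input flows through Linear(2, 3) and then Linear(3, 4):

import torch

x = torch.randn(5, 2)
out = s(x)        # __call__ dispatches to forward; x passes through each sub-module in turn
print(out.shape)  # torch.Size([5, 4])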

ModuleList

Holds a number of sub-modules in a list.

Example::

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

        def forward(self, x):
            # ModuleList can act as an iterable, or be indexed using ints
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
            return x
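
A quick usage check of MyModule (a sketch): because ModuleList registers its entries as sub-modules, all ten Linear layers are visible to parameters(), which a plain Python list would not give you.

import torch

m = MyModule()
x = torch.randn(3, 10)
print(m(x).shape)                 # torch.Size([3, 10])
print(len(list(m.parameters())))  # 20: weight and bias for each of the 10 Linear layers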

This article was written by fmujie and is licensed under Creative Commons Attribution 3.0; it may be freely reproduced and quoted, provided the author is credited and the source is cited.
