Save and Load the Model

PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:
model = models.vgg16(weights='IMAGENET1K_V1')
# model.state_dict() maps dotted names to the module's parameters plus its
# persistent buffers (e.g. BatchNorm running statistics).
# It does NOT contain optimizer state; save optimizer.state_dict() separately for that.
torch.save(obj=model.state_dict(), f='model_weights.pth')
# Every nn.Module exposes state_dict(); it walks all submodules recursively.
To load model weights, you need to create an instance of the same model first, and then load the parameters using load_state_dict() method.
# No ``weights`` argument: this builds the same architecture, untrained,
# so the saved tensors can be copied into it.
model = models.vgg16()
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(state_dict=torch.load(f='model_weights.pth'))
# Switch to inference mode; this changes how Dropout and BatchNorm behave.
model.eval()

同时保存模型与优化器状态的方法(checkpoint)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    """Small LeNet-style CNN mapping 3x32x32 images to 10 class scores.

    The submodule attribute names (conv1, pool, conv2, fc1, fc2, fc3) become
    the keys of ``state_dict()``, so they must stay stable for previously
    saved checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return logits of shape (N, 10) for a batch ``x`` of shape (N, 3, 32, 32)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten every dimension except the batch one. Unlike
        # ``view(-1, 16 * 5 * 5)`` this never merges samples together: an
        # input with the wrong spatial size now fails in fc1 instead of
        # silently producing a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

net = Net()
print(net)

optimizer = optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)

# Extra training metadata stored next to the weights.
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4
# A resumable checkpoint needs more than the model weights: it also stores
# the optimizer state, the epoch counter and the last loss.
torch.save(dict(
    epoch=EPOCH,
    model_state_dict=net.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    loss=LOSS,
), PATH)

Remember to first initialize the model and optimizer, then load the dictionary locally.

# Build fresh model and optimizer objects first: load_state_dict copies the
# saved values into tensors that must already exist.
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch, loss = checkpoint['epoch'], checkpoint['loss']

# Choose the mode for what comes next:
model.eval()   # evaluation / inference
# - or -
model.train()  # resume training

MODULE SOURCE CODE

to

    def to(self, *args, **kwargs):
        r"""Move and/or cast this module's parameters and buffers, in place.

        Accepted call forms:

        .. function:: to(device=None, dtype=None, non_blocking=False)
           :noindex:

        .. function:: to(dtype, non_blocking=False)
           :noindex:

        .. function:: to(tensor, non_blocking=False)
           :noindex:

        .. function:: to(memory_format=torch.channels_last)
           :noindex:

        The signature mirrors :meth:`torch.Tensor.to`, with one restriction:
        :attr:`dtype`, when given, must be a floating point or complex dtype
        (a complex dtype additionally emits a warning). Only floating point
        and complex parameters/buffers are cast to it. Integral ones keep
        their dtype but are still moved to :attr:`device` if one is given.
        When :attr:`non_blocking` is set, the move is attempted
        asynchronously with respect to the host where possible, e.g. CPU
        tensors in pinned memory moving to a CUDA device. A
        :attr:`memory_format` is applied only to 4D and 5D tensors.

        .. note::
            This method modifies the module in-place.

        Args:
            device (:class:`torch.device`): the desired device of the parameters
                and buffers in this module
            dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
                the parameters and buffers in this module
            tensor (torch.Tensor): Tensor whose dtype and device are the desired
                dtype and device for all parameters and buffers in this module
            memory_format (:class:`torch.memory_format`): the desired memory
                format for 4D/5D parameters and buffers in this module (keyword
                only argument)

        Returns:
            Module: self

        Raises:
            TypeError: if :attr:`dtype` is neither floating point nor complex.

        Examples::

            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> linear = nn.Linear(2, 2)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]])
            >>> linear.to(torch.double)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]], dtype=torch.float64)
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
            >>> gpu1 = torch.device("cuda:1")
            >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
            >>> cpu = torch.device("cpu")
            >>> linear.to(cpu)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16)

            >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.3741+0.j,  0.2382+0.j],
                    [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
            >>> linear(torch.ones(3, 2, dtype=torch.cdouble))
            tensor([[0.6122+0.j, 0.1150+0.j],
                    [0.6122+0.j, 0.1150+0.j],
                    [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)

        """
        parsed = torch._C._nn._parse_to(*args, **kwargs)
        device, dtype, non_blocking, convert_to_format = parsed

        if dtype is not None:
            # Reject integral/bool targets up front; only float/complex casts are allowed.
            if not dtype.is_floating_point and not dtype.is_complex:
                raise TypeError('nn.Module.to only accepts floating point or complex '
                                f'dtypes, but got desired dtype={dtype}')
            if dtype.is_complex:
                warnings.warn(
                    "Complex modules are a new feature under active development whose design may change, "
                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
                    "if a complex module does not work as expected.")

        def convert(t):
            # Integral tensors keep their dtype (None means "unchanged") but still move device.
            target_dtype = dtype if (t.is_floating_point() or t.is_complex()) else None
            # The requested memory format only makes sense for 4D/5D tensors.
            if convert_to_format is not None and t.dim() in (4, 5):
                return t.to(device, target_dtype, non_blocking, memory_format=convert_to_format)
            return t.to(device, target_dtype, non_blocking)

        return self._apply(convert)

示例

import torch
import torch.nn as nn


class Net(nn.Module):
    """Toy container used only to inspect nn.Module internals.

    It registers three child modules and deliberately defines no forward().
    """

    def __init__(self):
        # Initialise the nn.Module base class before registering submodules.
        super().__init__()
        self.linear1 = nn.Linear(2, 3)
        self.linear2 = nn.Linear(3, 4)
        self.batch_norm = nn.BatchNorm2d(4)


test_net = Net()

# Child modules live in the private ``_modules`` dict, keyed by attribute name.
linear1 = test_net._modules['linear1']
print(linear1)
'''
Linear(in_features=2, out_features=3, bias=True)
'''
weight1 = linear1.weight
print(weight1, weight1.dtype)
'''
Parameter containing:
tensor([[ 0.5884,  0.4096],
        [-0.0927,  0.3592],
        [ 0.6238, -0.1112]], requires_grad=True) 
torch.float32
'''
# Call to() to cast every floating-point tensor in the model to double precision.
test_net.to(torch.double)
print(test_net._modules['linear1'].weight.dtype)
'''
torch.float64
'''
# ``_parameters`` is a dict attribute, not a method, so print it without calling it.
print(test_net._parameters)
'''
OrderedDict()
'''
# _parameters does not walk submodules. It only holds nn.Parameter objects
# registered directly on this module, so it is empty for Net.
print(test_net._buffers)
'''
OrderedDict()
'''
# state_dict is a method. Call it to get the dict; printing the bare
# attribute would only show the bound method.
print(test_net.state_dict())
'''
OrderedDict([('linear1.weight', tensor([[ 0.1710,  0.2324],
        [ 0.5988, -0.0737],
        [-0.4426,  0.4938]], dtype=torch.float64)), ('linear1.bias', tensor([-0.6435,  0.2064,  0.4859], dtype=torch.float64)), ('linear2.weight', tensor([[-0.1337,  0.2572,  0.5521],
        [-0.1224,  0.2688,  0.3103],
        [-0.0960, -0.2258, -0.0193],
        [-0.0830, -0.3477, -0.1718]], dtype=torch.float64)), ('linear2.bias', tensor([-0.3606, -0.5329, -0.3330, -0.2828], dtype=torch.float64)), ('batch_norm.weight', tensor([1., 1., 1., 1.], dtype=torch.float64)), ('batch_norm.bias', tensor([0., 0., 0., 0.], dtype=torch.float64)), ('batch_norm.running_mean', tensor([0., 0., 0., 0.], dtype=torch.float64)), ('batch_norm.running_var', tensor([1., 1., 1., 1.], dtype=torch.float64)), ('batch_norm.num_batches_tracked', tensor(0))])

'''

__getattr__

def __getattr__(self, name: str) -> Any:
    """Resolve ``name`` from the module's registered parameters, buffers or submodules.

    Python only calls ``__getattr__`` after normal attribute lookup fails. The
    registries are searched in a fixed order: parameters first, then buffers,
    then submodules; the first hit wins.

    Raises:
        AttributeError: if ``name`` is in none of the registries.
    """
    instance_dict = self.__dict__
    for registry_name in ('_parameters', '_buffers', '_modules'):
        # A registry may be absent, e.g. before the base-class __init__ has run.
        if registry_name not in instance_dict:
            continue
        registry = instance_dict[registry_name]
        if name in registry:
            return registry[name]
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

_save_to_state_dict

将当前模块自身(不含子模块)的参数和持久化buffer放入destination字典中。前述net.state_dict()调用的是Module.state_dict,它先对当前模块调用_save_to_state_dict,再递归地对每个子模块调用state_dict。

def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Write this module's own parameters and persistent buffers into ``destination``.

    Child modules are not visited here; ``state_dict`` handles the recursion.
    Each entry is keyed ``prefix + name``. Unless ``keep_vars`` is true, the
    stored tensors are detached from autograd. ``None`` entries and
    non-persistent buffers are skipped. If the class overrides
    ``get_extra_state``, its result is stored under
    ``prefix + _EXTRA_STATE_KEY_SUFFIX``.
    """
    # Two linear passes: parameters first, then buffers.
    for key, tensor in self._parameters.items():
        if tensor is None:
            continue
        destination[prefix + key] = tensor if keep_vars else tensor.detach()
    for key, tensor in self._buffers.items():
        if tensor is None or key in self._non_persistent_buffers_set:
            continue
        destination[prefix + key] = tensor if keep_vars else tensor.detach()
    extra_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    # Only subclasses that override get_extra_state contribute extra state.
    resolved = getattr(self.__class__, "get_extra_state", Module.get_extra_state)
    if resolved is not Module.get_extra_state:
        destination[extra_key] = self.get_extra_state()
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
    """Return a dict of this module's and all descendants' parameters and persistent buffers.

    Keys are dotted names (``prefix`` + child path + tensor name). When a new
    ``OrderedDict`` is created, a ``_metadata`` attribute records
    ``{'version': self._version}`` per module path. Pre-hooks run before this
    module's tensors are saved; post-hooks run afterwards and may replace the
    returned dict.

    The legacy positional form ``state_dict(destination, prefix, keep_vars)``
    is still accepted but emits a deprecation warning.
    """
    # TODO: Remove `args` and the parsing logic when BC allows.
    if args:
        # Positional arguments only fill slots still at their defaults.
        if destination is None:
            destination = args[0]
        if len(args) > 1 and prefix == '':
            prefix = args[1]
        if len(args) > 2 and keep_vars is False:
            keep_vars = args[2]
        # DeprecationWarning is ignored by default
        warnings.warn(
            "Positional args are being deprecated, use kwargs instead. Refer to "
            "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
            " for details.")

    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    local_metadata = {'version': self._version}
    if hasattr(destination, "_metadata"):
        # prefix[:-1] drops the trailing '.' so the root module maps to ''.
        destination._metadata[prefix[:-1]] = local_metadata

    for pre_hook in self._state_dict_pre_hooks.values():
        pre_hook(self, prefix, keep_vars)
    # This module's own tensors first, then every child under "<prefix><name>.".
    self._save_to_state_dict(destination, prefix, keep_vars)
    for child_name, child in self._modules.items():
        if child is None:
            continue
        child.state_dict(destination=destination, prefix=f"{prefix}{child_name}.", keep_vars=keep_vars)
    for post_hook in self._state_dict_hooks.values():
        replacement = post_hook(self, destination, prefix, local_metadata)
        if replacement is not None:
            destination = replacement
    return destination

_load_from_state_dict

从一个state_dict中读取当前模块自身的parameters和buffers,并拷贝到模型中

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        r"""Copy this module's own parameters and persistent buffers from ``state_dict``.

        Only entries named ``prefix + <name>`` for tensors registered directly
        on this module are consumed. Child modules are not visited here; the
        caller (``load_state_dict``) recurses. Loading problems are recorded in
        the caller-supplied lists rather than raised.

        Args:
            state_dict: mapping from dotted names to tensors or tensor-likes.
            prefix: dotted prefix of this module inside ``state_dict``,
                e.g. ``'layer1.'`` (empty for the root module).
            local_metadata: per-module metadata. A truthy
                ``'assign_to_params_buffers'`` entry makes the loader assign the
                incoming tensors as attributes instead of ``copy_``-ing into the
                existing ones.
            strict: when True, names absent from ``state_dict`` are appended to
                ``missing_keys`` and unknown names under ``prefix`` to
                ``unexpected_keys``.
            missing_keys: list extended in place with missing key names.
            unexpected_keys: list extended in place with unexpected key names.
            error_msgs: list extended in place with error descriptions
                (non-tensor values, shape mismatches, failed copies).
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        # Collect name -> tensor for this module's own parameters and persistent
        # buffers (non-persistent buffers and None entries are skipped).
        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}
        assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)

        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]
                # Non-tensor checkpoint values are reported and skipped.
                if not torch.overrides.is_tensor_like(input_param):
                    error_msgs.append(f'While copying the parameter named "{key}", '
                                      'expected torch.Tensor or Tensor-like object from checkpoint but '
                                      f'received {type(input_param)}'
                                      )
                    continue

                # This is used to avoid copying uninitialized parameters into
                # non-lazy modules, since they don't have the hook to do the checks
                # in such case, it will error when accessing the .shape attribute.
                is_param_lazy = torch.nn.parameter.is_lazy(param)
                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]

                if not is_param_lazy and input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                      'the shape in current model is {}.'
                                      .format(key, input_param.shape, param.shape))
                    continue

                # Copying real data into a meta tensor is a no-op, so warn unless
                # the caller asked for assignment instead of in-place copy.
                if param.is_meta and not input_param.is_meta and not assign_to_params_buffers:
                    warnings.warn(f'for {key}: copying from a non-meta parameter in the checkpoint to a meta '
                                  'parameter in the current model, which is a no-op. (Did you mean to '
                                  'pass `assign=True` to assign items in the state dictionary to their '
                                  'corresponding key in the module instead of copying them in place?)')

                try:
                    with torch.no_grad():
                        if assign_to_params_buffers:
                            # Shape checks are already done above
                            # Replace the attribute outright; keep it an nn.Parameter
                            # if the current value is one.
                            if (isinstance(param, torch.nn.Parameter) and
                                    not isinstance(input_param, torch.nn.Parameter)):
                                setattr(self, name, torch.nn.Parameter(input_param))
                            else:
                                setattr(self, name, input_param)
                        else:
                            # Copy the checkpoint value into the existing tensor in place.
                            param.copy_(input_param)
                except Exception as ex:
                    error_msgs.append(f'While copying the parameter named "{key}", '
                                      f'whose dimensions in the model are {param.size()} and '
                                      f'whose dimensions in the checkpoint are {input_param.size()}, '
                                      f'an exception occurred : {ex.args}.'
                                      )
            elif strict:
                missing_keys.append(key)

        # Extra (non-tensor) state is loaded only if the class overrides set_extra_state;
        # otherwise a present extra-state key is unexpected under strict mode.
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            # Any key under this prefix whose first path component names neither
            # a child module nor one of this module's own tensors is unexpected.
            for key in state_dict.keys():
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix):]
                    input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)

load_state_dict

# Recursively copy the values stored in a checkpoint into the model.
def load_state_dict(self, state_dict: Mapping[str, Any],
                    strict: bool = True, assign: bool = False):
    """Copy parameters and buffers from ``state_dict`` into this module and its descendants.

    The module tree is walked recursively. Each module receives only the
    entries under its own dotted prefix, and its ``_load_from_state_dict``
    consumes them. Load-state-dict post hooks are then invoked.

    Args:
        state_dict: mapping from dotted names to tensors. An optional
            ``_metadata`` attribute (as produced by ``state_dict()``) supplies
            per-module metadata. Neither the mapping nor its metadata is
            modified.
        strict: when True, missing or unexpected keys are turned into errors.
        assign: when True, modules are asked to assign the incoming tensors
            instead of copying into their existing ones.

    Returns:
        ``_IncompatibleKeys`` with ``missing_keys`` and ``unexpected_keys``.

    Raises:
        TypeError: if ``state_dict`` is not a Mapping.
        AssertionError: if a load-state-dict post hook returns a value.
        RuntimeError: if any errors were collected (including missing or
            unexpected keys when ``strict`` is True).
    """
    if not isinstance(state_dict, Mapping):
        raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        # mypy isn't aware that "_metadata" exists in state_dict
        state_dict._metadata = metadata  # type: ignore[attr-defined]

    def load(module, local_state_dict, prefix=''):
        # Copy the per-module metadata: the dicts inside ``metadata`` belong to
        # the caller's state_dict, and writing the ``assign`` flag into them
        # would leak into every later load of the same state_dict.
        local_metadata = {} if metadata is None else dict(metadata.get(prefix[:-1], {}))
        if assign:
            local_metadata['assign_to_params_buffers'] = assign
        # Always collect missing/unexpected keys; the outer ``strict`` decides
        # below whether they become errors.
        module._load_from_state_dict(
            local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                child_prefix = prefix + name + '.'
                # Hand each child only the entries under its own prefix.
                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                load(child, child_state_dict, child_prefix)

        # Note that the hook can modify missing_keys and unexpected_keys.
        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in module._load_state_dict_post_hooks.values():
            out = hook(module, incompatible_keys)
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            if out is not None:
                raise AssertionError(
                    "Hooks registered with ``register_load_state_dict_post_hook`` are not "
                    "expected to return new values, if incompatible_keys need to be modified, "
                    "it should be done inplace."
                )

    load(self, state_dict)
    del load

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                           self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)

_named_members

# Generic lookup helper: yields (qualified_name, member) pairs for whatever
# ``get_members_fn`` extracts from each module (parameters, buffers, ...).
def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate: bool = True):
    r"""Yield ``(dotted_name, member)`` for members of this module and, optionally, its descendants.

    ``get_members_fn(module)`` returns ``(name, member)`` pairs for a single
    module. With ``recurse`` the module tree comes from ``named_modules``;
    otherwise only ``self`` (named ``prefix``) is inspected. ``None`` members
    are skipped. With ``remove_duplicate``, a member object reachable under
    several names is yielded only once.
    """
    seen = set()
    if recurse:
        owners = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate)
    else:
        owners = [(prefix, self)]
    for owner_prefix, owner in owners:
        for member_name, member in get_members_fn(owner):
            if member is None or member in seen:
                continue
            if remove_duplicate:
                seen.add(member)
            qualified = owner_prefix + '.' + member_name if owner_prefix else member_name
            yield qualified, member

parameters

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        r"""Yield this module's parameters, by default including all submodules'.

        This is what is typically handed to an optimizer. It is a thin wrapper
        over :meth:`named_parameters` that drops the names.

        Args:
            recurse (bool): when True (default), parameters of every submodule
                are included; when False, only parameters registered directly on
                this module are yielded.

        Yields:
            Parameter: module parameter

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for param in model.parameters():
            >>>     print(type(param), param.size())
            <class 'torch.Tensor'> torch.Size([20])
            <class 'torch.Tensor'> torch.Size([20, 1, 5, 5])

        """
        yield from (param for _, param in self.named_parameters(recurse=recurse))

_parameters只包含当前模块自身直接注册的参数,不包含子模块的参数;parameters()则会递归地迭代返回所有模块的参数

# parameters() recurses into every submodule, unlike the private _parameters dict.
for param in test_net.parameters():
    print(param)
'''
Parameter containing:
tensor([[ 0.2101, -0.1331],
        [-0.4235, -0.1817],
        [ 0.4881,  0.2107]], dtype=torch.float64, requires_grad=True)
Parameter containing:
tensor([-0.3277,  0.2989, -0.3436], dtype=torch.float64, requires_grad=True)
Parameter containing:
tensor([[-0.3632, -0.0776,  0.2335],
        [-0.2908, -0.5474, -0.3428],
        [-0.1266, -0.5220,  0.2134],
        [ 0.1425,  0.1375, -0.3167]], dtype=torch.float64, requires_grad=True)
Parameter containing:
tensor([-0.4652,  0.5581, -0.0013, -0.1475], dtype=torch.float64,
       requires_grad=True)
Parameter containing:
tensor([1., 1., 1., 1.], dtype=torch.float64, requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0.], dtype=torch.float64, requires_grad=True)
'''

named_parameters

def named_parameters(
        self,
        prefix: str = '',
        recurse: bool = True,
        remove_duplicate: bool = True
) -> Iterator[Tuple[str, Parameter]]:
    r"""Yield ``(dotted_name, parameter)`` pairs for this module's parameters.

    Delegates to ``_named_members`` with an extractor that returns a single
    module's own ``_parameters``. ``_named_members`` supplies the traversal
    over submodules.

    Args:
        prefix (str): prefix to prepend to all parameter names.
        recurse (bool): when True, include parameters of every submodule;
            when False, only parameters registered directly on this module.
        remove_duplicate (bool, optional): whether a parameter shared by
            several modules is yielded only once. Defaults to True.

    Yields:
        (str, Parameter): Tuple containing the name and parameter

    Example::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> for name, param in self.named_parameters():
        >>>     if name in ['bias']:
        >>>         print(param.size())

    """
    def _own_parameters(module):
        return module._parameters.items()

    yield from self._named_members(
        _own_parameters,
        prefix=prefix,
        recurse=recurse,
        remove_duplicate=remove_duplicate,
    )

注:parameters()调用named_parameters,named_parameters又调用_named_members,并传入lambda module: module._parameters.items(),这个函数只返回单个模块自身的参数;而_named_members会通过named_modules遍历传入模块的所有子模块,所以parameters()最终能得到所有模块的参数。下述buffers同理。

buffers

def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
    r"""Yield this module's buffers, by default including all submodules'.

    A thin wrapper over :meth:`named_buffers` that drops the names.

    Args:
        recurse (bool): when True (default), buffers of every submodule are
            included; when False, only buffers registered directly on this
            module are yielded.

    Yields:
        torch.Tensor: module buffer

    Example::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> for buf in model.buffers():
        >>>     print(type(buf), buf.size())
        <class 'torch.Tensor'> torch.Size([20])
        <class 'torch.Tensor'> torch.Size([20, 1, 5, 5])

    """
    yield from (buf for _, buf in self.named_buffers(recurse=recurse))

def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
    r"""Yield ``(dotted_name, buffer)`` pairs for this module's buffers.

    Delegates to ``_named_members`` with an extractor returning a single
    module's own ``_buffers``. ``_named_members`` walks the submodules.

    Args:
        prefix (str): prefix to prepend to all buffer names.
        recurse (bool, optional): when True, include buffers of every
            submodule; when False, only buffers registered directly on this
            module. Defaults to True.
        remove_duplicate (bool, optional): whether a buffer shared by several
            modules is yielded only once. Defaults to True.

    Yields:
        (str, torch.Tensor): Tuple containing the name and buffer

    Example::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> for name, buf in self.named_buffers():
        >>>     if name in ['running_var']:
        >>>         print(buf.size())

    """
    def _own_buffers(module):
        return module._buffers.items()

    yield from self._named_members(
        _own_buffers,
        prefix=prefix,
        recurse=recurse,
        remove_duplicate=remove_duplicate,
    )

children

def children(self) -> Iterator['Module']:
    r"""Yield the direct child modules (no deeper descendants).

    A thin wrapper over ``named_children`` that drops the names, so the same
    child object is yielded only once.

    Yields:
        Module: a child module
    """
    yield from (child for _, child in self.named_children())

def named_children(self) -> Iterator[Tuple[str, 'Module']]:
    r"""Yield ``(name, child)`` for each direct child module.

    ``None`` entries in ``_modules`` are skipped. A child object registered
    under several names is yielded only under the first one.

    Yields:
        (str, Module): Tuple containing a name and child module

    Example::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> for name, module in model.named_children():
        >>>     if name in ['conv4', 'conv5']:
        >>>         print(module)

    """
    already_yielded = set()
    for child_name, child in self._modules.items():
        if child is None or child in already_yielded:
            continue
        already_yielded.add(child)
        yield child_name, child
# named_children() yields (name, module) for direct children only.
for child_pair in test_net.named_children():
    print(child_pair)
'''
('linear1', Linear(in_features=2, out_features=3, bias=True))
('linear2', Linear(in_features=3, out_features=4, bias=True))
('batch_norm', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
'''
# The same children, straight from the private registry dict.
print(test_net._modules)
'''
OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])
'''

modules

def modules(self) -> Iterator['Module']:
    r"""Yield every module in the network: this module first, then its descendants.

    A thin wrapper over ``named_modules`` that drops the names.

    Yields:
        Module: a module in the network

    Note:
        Duplicate modules are returned only once. In the following
        example, ``l`` will be returned only once.

    Example::

        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.modules()):
        ...     print(idx, '->', m)

        0 -> Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        )
        1 -> Linear(in_features=2, out_features=2, bias=True)

    """
    yield from (module for _, module in self.named_modules())

def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
    r"""Yield ``(dotted_name, module)`` for this module and, depth-first, all descendants.

    Args:
        memo: set of modules already yielded; a module found in it is skipped
            together with its subtree
        prefix: name given to this module; children are named
            ``prefix + '.' + child_name`` (or just ``child_name`` when
            ``prefix`` is empty)
        remove_duplicate: when True, modules are recorded in ``memo`` so
            shared instances are yielded once; when False, nothing is recorded

    Yields:
        (str, Module): Tuple of name and module

    Note:
        Duplicate modules are returned only once. In the following
        example, ``l`` will be returned only once.

    Example::

        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.named_modules()):
        ...     print(idx, '->', m)

        0 -> ('', Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        ))
        1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

    """
    visited = set() if memo is None else memo
    if self in visited:
        return
    if remove_duplicate:
        visited.add(self)
    # Pre-order: the module itself comes before any of its children.
    yield prefix, self
    for child_name, child in self._modules.items():
        if child is None:
            continue
        child_prefix = prefix + '.' + child_name if prefix else child_name
        yield from child.named_modules(visited, child_prefix, remove_duplicate)
for module_pair in test_net.named_modules():
    print(module_pair)
    # Yields the module itself ('' name) followed by all of its descendants.

print(test_net._modules)
# Holds only the direct children.
'''
('', Net(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (linear2): Linear(in_features=3, out_features=4, bias=True)
  (batch_norm): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
))
('linear1', Linear(in_features=2, out_features=3, bias=True))
('linear2', Linear(in_features=3, out_features=4, bias=True))
('batch_norm', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])
'''

本文由 fmujie 创作,采用 知识共享署名 3.0,可自由转载、引用,但需署名作者且注明文章出处。

还不快抢沙发

添加新评论

召唤看板娘