Save and Load the Model
PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:
import torch
from torchvision import models

model = models.vgg16(weights='IMAGENET1K_V1')
# A full checkpoint can hold the model's parameters, the optimizer state, and buffer variables
# such as batch-normalization statistics; model.state_dict() keeps only the model's weights.
torch.save(model.state_dict(), 'model_weights.pth')
# Every model has a state_dict(): a dictionary containing all of its parameters and buffer variables.
To load model weights, you first need to create an instance of the same model, and then load the parameters with the load_state_dict() method.
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth'))
# Inference mode: affects dropout and batch-normalization behavior
model.eval()
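As a side note (not part of the original tutorial), if the weights were saved on a GPU machine and are loaded on a CPU-only one, torch.load takes a map_location argument that remaps the stored tensors; a minimal sketch using the same file as above:
# Minimal sketch: remap GPU-saved weights onto the CPU before loading them
state = torch.load('model_weights.pth', map_location=torch.device('cpu'))
model = models.vgg16()
model.load_state_dict(state)
model.eval()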
Saving the optimizer state along with the model weights
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Net()
print(net)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Collect all relevant information and build your dictionary.
# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4
# Save not only the model weights but also the optimizer state, the loss, and the epoch
torch.save({
    'epoch': EPOCH,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': LOSS,
}, PATH)
Remember to first initialize the model and optimizer, then load the dictionary locally.
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
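To make the resume step concrete, here is a minimal added sketch of continuing training from the checkpoint; train_loader and the CrossEntropyLoss criterion are placeholders, not part of the original example:
model.train()
criterion = nn.CrossEntropyLoss()  # placeholder criterion for this sketch
for resumed_epoch in range(epoch + 1, epoch + 3):   # continue from the saved epoch
    for inputs, labels in train_loader:              # train_loader is a hypothetical DataLoader
        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, labels)
        batch_loss.backward()
        optimizer.step()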
MODULE SOURCE CODE
to
def to(self, *args, **kwargs):
r"""Moves and/or casts the parameters and buffers.
This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
:noindex:
.. function:: to(dtype, non_blocking=False)
:noindex:
.. function:: to(tensor, non_blocking=False)
:noindex:
.. function:: to(memory_format=torch.channels_last)
:noindex:
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point or complex :attr:`dtype`\ s. In addition, this method will
only cast the floating point or complex parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
.. note::
This method modifies the module in-place.
Args:
device (:class:`torch.device`): the desired device of the parameters
and buffers in this module
dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
the parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
memory_format (:class:`torch.memory_format`): the desired memory
format for 4D parameters and buffers in this module (keyword
only argument)
Returns:
Module: self
Examples::
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
[-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
[-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
[-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
[-0.5112, -0.2324]], dtype=torch.float16)
>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j, 0.2382+0.j],
[ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
[0.6122+0.j, 0.1150+0.j],
[0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
"""
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if dtype is not None:
if not (dtype.is_floating_point or dtype.is_complex):
raise TypeError('nn.Module.to only accepts floating point or complex '
f'dtypes, but got desired dtype={dtype}')
if dtype.is_complex:
warnings.warn(
"Complex modules are a new feature under active development whose design may change, "
"and some modules might not work as expected when using complex tensors as parameters or buffers. "
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
"if a complex module does not work as expected.")
def convert(t):
if convert_to_format is not None and t.dim() in (4, 5):
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
non_blocking, memory_format=convert_to_format)
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
return self._apply(convert)
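As an added side note (not from the original post), the convert_to_format branch in convert() only touches 4D/5D tensors, which is how Module.to(memory_format=torch.channels_last) rewrites convolution weights while leaving 1D tensors such as biases alone:
conv = nn.Conv2d(3, 8, 3)
conv.to(memory_format=torch.channels_last)
# the 4D weight tensor is now stored channels-last
print(conv.weight.is_contiguous(memory_format=torch.channels_last))
'''
True
'''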
Example
import torch
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        # Call the parent class (Module) __init__
        super(Net, self).__init__()
        self.linear1 = nn.Linear(2, 3)
        self.linear2 = nn.Linear(3, 4)
        self.batch_norm = nn.BatchNorm2d(4)
test_net = Net()
linear1 = test_net._modules['linear1']
print(linear1)
'''
Linear(in_features=2, out_features=3, bias=True)
'''
weight1 = linear1.weight
print(weight1, weight1.dtype)
'''
Parameter containing:
tensor([[ 0.5884, 0.4096],
[-0.0927, 0.3592],
[ 0.6238, -0.1112]], requires_grad=True)
torch.float32
'''
# Now call to() to cast the model's floating-point tensors to double precision
test_net.to(torch.double)
print(test_net._modules['linear1'].weight.dtype)
'''
torch.float64
'''
print(test_net._parameters)
'''
OrderedDict()
'''
# _parameters does not recurse into submodules; it only checks whether the current module itself
# holds any nn.Parameter instances, so here it is an empty dict
print(test_net._buffers)
'''
OrderedDict()
'''
print(test_net.state_dict())
'''
OrderedDict([('linear1.weight', tensor([[ 0.1710, 0.2324],
[ 0.5988, -0.0737],
[-0.4426, 0.4938]], dtype=torch.float64)), ('linear1.bias', tensor([-0.6435, 0.2064, 0.4859], dtype=torch.float64)), ('linear2.weight', tensor([[-0.1337, 0.2572, 0.5521],
[-0.1224, 0.2688, 0.3103],
[-0.0960, -0.2258, -0.0193],
[-0.0830, -0.3477, -0.1718]], dtype=torch.float64)), ('linear2.bias', tensor([-0.3606, -0.5329, -0.3330, -0.2828], dtype=torch.float64)), ('batch_norm.weight', tensor([1., 1., 1., 1.], dtype=torch.float64)), ('batch_norm.bias', tensor([0., 0., 0., 0.], dtype=torch.float64)), ('batch_norm.running_mean', tensor([0., 0., 0., 0.], dtype=torch.float64)), ('batch_norm.running_var', tensor([1., 1., 1., 1.], dtype=torch.float64)), ('batch_norm.num_batches_tracked', tensor(0))])
'''
__getattr__
def __getattr__(self, name: str) -> Any:
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
return _parameters[name]
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
return _buffers[name]
if '_modules' in self.__dict__:
modules = self.__dict__['_modules']
if name in modules:
return modules[name]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
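A small added illustration (not from the original post): submodules are stored in _modules rather than in the instance __dict__, so it is this __getattr__ fallback that makes test_net.linear1 work:
# 'linear1' is not in __dict__; __getattr__ finds it in self._modules
print('linear1' in test_net.__dict__)
print(test_net.linear1)
print(test_net.linear1 is test_net._modules['linear1'])
'''
False
Linear(in_features=2, out_features=3, bias=True)
True
'''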
_save_to_state_dict
Collects the current module's parameters and buffers into a single dictionary. The net.state_dict() call shown earlier goes through Module.state_dict, which in turn calls _save_to_state_dict.
def _save_to_state_dict(self, destination, prefix, keep_vars):
# Two for loops, O(n): iterate over the current module's parameters and buffers and put them into the destination dict
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
destination[extra_state_key] = self.get_extra_state()
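The check against _non_persistent_buffers_set explains why buffers registered with persistent=False never appear in a state_dict; a small added sketch to illustrate (not from the original post):
class WithBuffers(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('kept', torch.zeros(2))                       # persistent (default)
        self.register_buffer('scratch', torch.zeros(2), persistent=False)  # excluded from state_dict

print(list(WithBuffers().state_dict().keys()))
'''
['kept']
'''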
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
# TODO: Remove `args` and the parsing logic when BC allows.
if len(args) > 0:
if destination is None:
destination = args[0]
if len(args) > 1 and prefix == '':
prefix = args[1]
if len(args) > 2 and keep_vars is False:
keep_vars = args[2]
# DeprecationWarning is ignored by default
warnings.warn(
"Positional args are being deprecated, use kwargs instead. Refer to "
"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
" for details.")
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
if hasattr(destination, "_metadata"):
destination._metadata[prefix[:-1]] = local_metadata
for hook in self._state_dict_pre_hooks.values():
hook(self, prefix, keep_vars)
# Store the current module's own parameters and buffers into the destination dict
self._save_to_state_dict(destination, prefix, keep_vars)
# Recurse into every submodule, storing its parameters and buffers into destination as well
for name, module in self._modules.items():
if module is not None:
module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
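One detail visible in _save_to_state_dict above: by default the tensors in the returned dict are detach()-ed and no longer require grad; passing keep_vars=True keeps the original Parameter objects. A quick added check against test_net (not from the original post):
sd = test_net.state_dict()
sd_vars = test_net.state_dict(keep_vars=True)
print(type(sd['linear1.weight']), sd['linear1.weight'].requires_grad)
print(type(sd_vars['linear1.weight']), sd_vars['linear1.weight'].requires_grad)
'''
<class 'torch.Tensor'> False
<class 'torch.nn.parameter.Parameter'> True
'''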
_load_from_state_dict
Loads parameters and buffers from a state_dict into the current module.
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
# First gather the local keys (parameters and persistent buffers) into local_state
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if not torch.overrides.is_tensor_like(input_param):
error_msgs.append(f'While copying the parameter named "{key}", '
'expected torch.Tensor or Tensor-like object from checkpoint but '
f'received {type(input_param)}'
)
continue
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they dont have the hook to do the checks
# in such case, it will error when accessing the .shape attribute.
is_param_lazy = torch.nn.parameter.is_lazy(param)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if not is_param_lazy and input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
continue
if param.is_meta and not input_param.is_meta and not assign_to_params_buffers:
warnings.warn(f'for {key}: copying from a non-meta parameter in the checkpoint to a meta '
'parameter in the current model, which is a no-op. (Did you mean to '
'pass `assign=True` to assign items in the state dictionary to their '
'corresponding key in the module instead of copying them in place?)')
try:
with torch.no_grad():
if assign_to_params_buffers:
# Shape checks are already done above
if (isinstance(param, torch.nn.Parameter) and
not isinstance(input_param, torch.nn.Parameter)):
setattr(self, name, torch.nn.Parameter(input_param))
else:
setattr(self, name, input_param)
else:
# Copy the value coming from the external state_dict into param in place
param.copy_(input_param)
except Exception as ex:
error_msgs.append(f'While copying the parameter named "{key}", '
f'whose dimensions in the model are {param.size()} and '
f'whose dimensions in the checkpoint are {input_param.size()}, '
f'an exception occurred : {ex.args}.'
)
elif strict:
missing_keys.append(key)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
missing_keys.append(extra_state_key)
elif strict and (extra_state_key in state_dict):
unexpected_keys.append(extra_state_key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)
load_state_dict
# Recursively assign the values from the checkpoint to the model
def load_state_dict(self, state_dict: Mapping[str, Any],
strict: bool = True, assign: bool = False):
if not isinstance(state_dict, Mapping):
raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")
missing_keys: List[str] = []
unexpected_keys: List[str] = []
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = OrderedDict(state_dict)
if metadata is not None:
# mypy isn't aware that "_metadata" exists in state_dict
state_dict._metadata = metadata # type: ignore[attr-defined]
def load(module, local_state_dict, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
if assign:
local_metadata['assign_to_params_buffers'] = assign
module._load_from_state_dict(
local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
child_prefix = prefix + name + '.'
child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
load(child, child_state_dict, child_prefix)
# Note that the hook can modify missing_keys and unexpected_keys.
incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
for hook in module._load_state_dict_post_hooks.values():
out = hook(module, incompatible_keys)
assert out is None, (
"Hooks registered with ``register_load_state_dict_post_hook`` are not"
"expected to return new values, if incompatible_keys need to be modified,"
"it should be done inplace."
)
load(self, state_dict)
del load
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(
', '.join(f'"{k}"' for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(
', '.join(f'"{k}"' for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
return _IncompatibleKeys(missing_keys, unexpected_keys)
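Because load_state_dict collects missing_keys and unexpected_keys, passing strict=False turns key mismatches into a return value instead of a RuntimeError; a small added sketch (not from the original post):
pretrained = nn.Linear(2, 3).state_dict()        # keys: 'weight', 'bias'
target = nn.Sequential(nn.Linear(2, 3))          # expects keys: '0.weight', '0.bias'
result = target.load_state_dict(pretrained, strict=False)
print(result.missing_keys)
print(result.unexpected_keys)
'''
['0.weight', '0.bias']
['weight', 'bias']
'''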
_named_members
# Generic lookup helper that can yield a module's parameters, buffers, and so on
def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate: bool = True):
r"""Helper method for yielding various names + members of modules."""
memo = set()
# Calls named_modules, which yields every module (including self)
modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)]
# Iterate over all of those modules
for module_prefix, module in modules:
members = get_members_fn(module)
for k, v in members:
if v is None or v in memo:
continue
if remove_duplicate:
memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k
yield name, v
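The memo set is what removes duplicates: if the same submodule is registered twice, its parameters are yielded only once unless remove_duplicate=False is passed. A quick added check (not from the original post):
shared = nn.Linear(2, 2)
seq = nn.Sequential(shared, shared)
print(len(list(seq.parameters())))                               # weight and bias, each counted once
print(len(list(seq.named_parameters(remove_duplicate=False))))   # duplicates kept
'''
2
4
'''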
parameters
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
r"""Returns an iterator over module parameters.
This is typically passed to an optimizer.
Args:
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
Yields:
Parameter: module parameter
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>> print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
"""
for name, param in self.named_parameters(recurse=recurse):
yield param
_parameters
_parameters only contains the parameters of the current module itself, not those of its submodules; parameters() iterates over and returns all of them:
for p in test_net.parameters():
print(p)
'''
Parameter containing:
tensor([[ 0.2101, -0.1331],
[-0.4235, -0.1817],
[ 0.4881, 0.2107]], dtype=torch.float64, requires_grad=True)
Parameter containing:
tensor([-0.3277, 0.2989, -0.3436], dtype=torch.float64, requires_grad=True)
Parameter containing:
tensor([[-0.3632, -0.0776, 0.2335],
[-0.2908, -0.5474, -0.3428],
[-0.1266, -0.5220, 0.2134],
[ 0.1425, 0.1375, -0.3167]], dtype=torch.float64, requires_grad=True)
Parameter containing:
tensor([-0.4652, 0.5581, -0.0013, -0.1475], dtype=torch.float64,
requires_grad=True)
Parameter containing:
tensor([1., 1., 1., 1.], dtype=torch.float64, requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0.], dtype=torch.float64, requires_grad=True)
'''
named_parameters
def named_parameters(
self,
prefix: str = '',
recurse: bool = True,
remove_duplicate: bool = True
) -> Iterator[Tuple[str, Parameter]]:
r"""Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself.
Args:
prefix (str): prefix to prepend to all parameter names.
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
remove_duplicate (bool, optional): whether to remove the duplicated
parameters in the result. Defaults to True.
Yields:
(str, Parameter): Tuple containing the name and parameter
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>> if name in ['bias']:
>>> print(param.size())
"""
gen = self._named_members(
lambda module: module._parameters.items(),
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
yield from gen
Note: parameters() goes through named_parameters, and named_parameters goes through _named_members, passing in lambda module: module._parameters.items(), which yields each module's own parameters; _named_members then iterates over all submodules of the module it is given, so parameters() ultimately returns the parameters of every module, as the quick check below shows. The buffer methods that follow work the same way.
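As a quick added check of that call chain (not from the original post), named_parameters yields fully prefixed names for every submodule of test_net:
for name, param in test_net.named_parameters():
    print(name, tuple(param.shape))
'''
linear1.weight (3, 2)
linear1.bias (3,)
linear2.weight (4, 3)
linear2.bias (4,)
batch_norm.weight (4,)
batch_norm.bias (4,)
'''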
buffers
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
r"""Returns an iterator over module buffers.
Args:
recurse (bool): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
Yields:
torch.Tensor: module buffer
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
>>> print(type(buf), buf.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
"""
for _, buf in self.named_buffers(recurse=recurse):
yield buf
def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
r"""Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself.
Args:
prefix (str): prefix to prepend to all buffer names.
recurse (bool, optional): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module. Defaults to True.
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
Yields:
(str, torch.Tensor): Tuple containing the name and buffer
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>> if name in ['running_var']:
>>> print(buf.size())
"""
gen = self._named_members(
lambda module: module._buffers.items(),
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
yield from gen
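For test_net above, only the BatchNorm2d layer registers buffers, so named_buffers yields just its running statistics; a quick added check (not from the original post):
for name, buf in test_net.named_buffers():
    print(name, tuple(buf.shape))
'''
batch_norm.running_mean (4,)
batch_norm.running_var (4,)
batch_norm.num_batches_tracked ()
'''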
children
def children(self) -> Iterator['Module']:
r"""Returns an iterator over immediate children modules.
Yields:
Module: a child module
"""
for name, module in self.named_children():
yield module
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
r"""Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
Yields:
(str, Module): Tuple containing a name and child module
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
>>> if name in ['conv4', 'conv5']:
>>> print(module)
"""
memo = set()
for name, module in self._modules.items():
if module is not None and module not in memo:
memo.add(module)
yield name, module
for p in test_net.named_children():
print(p)
'''
('linear1', Linear(in_features=2, out_features=3, bias=True))
('linear2', Linear(in_features=3, out_features=4, bias=True))
('batch_norm', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
'''
print(test_net._modules)
'''
OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])
'''
modules
def modules(self) -> Iterator['Module']:
r"""Returns an iterator over all modules in the network.
Yields:
Module: a module in the network
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
... print(idx, '->', m)
0 -> Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
"""
for _, module in self.named_modules():
yield module
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
r"""Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
... print(idx, '->', m)
0 -> ('', Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
"""
if memo is None:
memo = set()
if self not in memo:
if remove_duplicate:
memo.add(self)
yield prefix, self # yield the module itself first, then recurse into its children
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
yield from module.named_modules(memo, submodule_prefix, remove_duplicate)
for p in test_net.named_modules():
    print(p)
# named_modules yields the module itself first, then all submodules
print(test_net._modules)
# _modules contains only the direct child modules
'''
('', Net(
(linear1): Linear(in_features=2, out_features=3, bias=True)
(linear2): Linear(in_features=3, out_features=4, bias=True)
(batch_norm): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
))
('linear1', Linear(in_features=2, out_features=3, bias=True))
('linear2', Linear(in_features=3, out_features=4, bias=True))
('batch_norm', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])
'''