Saving and loading models for inference in PyTorch
There are two approaches for saving and loading models for inference in PyTorch:
- The first is saving and loading the `state_dict`.
- The second is saving and loading the entire model.
For reference, this is how `torch.nn.Module.state_dict` is implemented (PyTorch source, reformatted here with its original indentation):

```python
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
    r"""Returns a dictionary containing references to the whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.
    Parameters and buffers set to ``None`` are not included.
    """
    # TODO: Remove `args` and the parsing logic when BC allows.
    if len(args) > 0:
        if destination is None:
            destination = args[0]
        if len(args) > 1 and prefix == '':
            prefix = args[1]
        if len(args) > 2 and keep_vars is False:
            keep_vars = args[2]
        # DeprecationWarning is ignored by default
        warnings.warn(
            "Positional args are being deprecated, use kwargs instead. Refer to "
            "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
            " for details.")

    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    local_metadata = dict(version=self._version)
    if hasattr(destination, "_metadata"):
        destination._metadata[prefix[:-1]] = local_metadata

    for hook in self._state_dict_pre_hooks.values():
        hook(self, prefix, keep_vars)
    self._save_to_state_dict(destination, prefix, keep_vars)
    for name, module in self._modules.items():
        if module is not None:
            module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
    for hook in self._state_dict_hooks.values():
        hook_result = hook(self, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
```
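In practice you never call the internals above directly: `model.state_dict()` simply returns an `OrderedDict` mapping parameter and buffer names to tensors. A minimal sketch of inspecting one, using a throwaway module (not part of the original recipe):

```python
import torch.nn as nn

# A tiny module just to illustrate what state_dict() contains: both
# learnable parameters (Linear) and persistent buffers (BatchNorm).
m = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
for name, tensor in m.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (3, 4)
# 0.bias (3,)
# 1.weight (3,)
# 1.bias (3,)
# 1.running_mean (3,)
# 1.running_var (3,)
# 1.num_batches_tracked ()
```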
Steps
1. Build the Dataset
2. Define and initialize the neural network
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)
```
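As a quick sanity check (my addition, not part of the recipe): the layer sizes assume 3×32×32 inputs, e.g. CIFAR-10 images, which is where the `16 * 5 * 5` flatten size comes from:

```python
# 32 -> conv1(5x5) -> 28 -> pool -> 14 -> conv2(5x5) -> 10 -> pool -> 5,
# so the flattened feature map is 16 * 5 * 5.
dummy = torch.randn(1, 3, 32, 32)
out = net(dummy)
print(out.shape)  # torch.Size([1, 10]) -- one row of 10-class logits
```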
3. Initialize the optimizer
```python
import torch.optim as optim

# The first argument is net.parameters(): it automatically traverses
# all submodules and hands every learnable parameter to the optimizer.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
```
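To see that `parameters()` really does recurse into submodules, a small sketch using the sibling method `named_parameters()`, which reports the module path of each tensor:

```python
for name, p in net.named_parameters():
    print(name, tuple(p.shape))
# conv1.weight (6, 3, 5, 5)
# conv1.bias (6,)
# ... one entry per parameter of every submodule of Net
```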
4. Save and load the model via `state_dict`
PATH = "state_dict_model.pt" # Save 获取该模型全部的参数和Buffer量,但是没有保存模型的结构 torch.save(net.state_dict(), PATH) # Load 先创建网络结构 model = Net() # 再加载权重 model.load_state_dict(torch.load(PATH)) # 仅针对inference,因为没有保存优化器部分 model.eval()
5. Save and load the entire model
```python
# Specify a path
PATH = "entire_model.pt"

# Save: pickles the whole module, structure included.
torch.save(net, PATH)

# Load: no need to define the network structure first.
model = torch.load(PATH)
model.eval()
```
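A caveat worth noting: this approach serializes the module with pickle, so the `Net` class definition must still be importable in the loading process, and the file is bound to the exact class and module layout used when saving. On recent PyTorch releases (2.6+), `torch.load` also defaults to `weights_only=True`, which refuses to unpickle arbitrary objects; loading an entire model then needs an explicit flag. A hedged sketch:

```python
# Only works if the Net class is importable here; otherwise pickle
# raises an AttributeError / ModuleNotFoundError. The state_dict
# approach avoids this coupling.
model = torch.load(PATH, weights_only=False)
model.eval()
```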
Saving and loading a general checkpoint in PyTorch
Saving and loading a general checkpoint model for ==inference or resuming training== can be helpful for picking up where you last left off. When saving a general checkpoint, you must save more than just the model's state_dict. ==It is important to also save the optimizer's state_dict==, as this contains buffers and parameters that are updated as the model trains. Other items that you may want to save are the epoch you left off on, the latest recorded training loss, external `torch.nn.Embedding` layers, and more, based on your own algorithm.
4. Save the general checkpoint
```python
# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

torch.save({
    'epoch': EPOCH,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': LOSS,
}, PATH)
```
5. Load the general checkpoint
```python
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
# Model parameters
model.load_state_dict(checkpoint['model_state_dict'])
# Optimizer state
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()
```
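To actually resume training, the restored `epoch` becomes the starting point of the loop. A minimal sketch, assuming a `train_loader` and `criterion` that are not defined in the original recipe:

```python
model.train()
# train_loader and criterion are hypothetical here; substitute your
# own DataLoader and loss function.
for ep in range(epoch, epoch + 5):  # continue for 5 more epochs
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
```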
Saving and loading multiple models in one file using PyTorch
2. Define and initialize the neural network
```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netA = Net()
netB = Net()
```
3. Initialize the optimizer
```python
optimizerA = optim.SGD(netA.parameters(), lr=0.001, momentum=0.9)
optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.9)
```
4. Save multiple models
```python
# Specify a path to save to
PATH = "model.pt"

torch.save({
    'modelA_state_dict': netA.state_dict(),
    'modelB_state_dict': netB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
}, PATH)
```
5. Load multiple models
```python
modelA = Net()
modelB = Net()
# Note: the optimizer names must match the load_state_dict calls below
# (the flattened original mixed optimModelA/optimizerA).
optimizerA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimizerB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
```
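Since the multi-model checkpoint is just an ordinary dictionary, it is easy to verify what a file contains before wiring up the models; a small sketch:

```python
checkpoint = torch.load(PATH)
print(list(checkpoint.keys()))
# ['modelA_state_dict', 'modelB_state_dict',
#  'optimizerA_state_dict', 'optimizerB_state_dict']
```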