Saving and loading models for inference in PyTorch

There are two approaches for saving and loading models for inference in PyTorch.

  • The first is saving and loading the model's state_dict;
  • the second is saving and loading the entire model.
For reference, this is roughly how torch.nn.Module.state_dict is implemented in the PyTorch source:

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        r"""Returns a dictionary containing references to the whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included. Keys are corresponding parameter and buffer names.
        Parameters and buffers set to ``None`` are not included.

        """

        # TODO: Remove `args` and the parsing logic when BC allows.
        if len(args) > 0:
            if destination is None:
                destination = args[0]
            if len(args) > 1 and prefix == '':
                prefix = args[1]
            if len(args) > 2 and keep_vars is False:
                keep_vars = args[2]
            # DeprecationWarning is ignored by default
            warnings.warn(
                "Positional args are being deprecated, use kwargs instead. Refer to "
                "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
                " for details.")

        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()

        local_metadata = dict(version=self._version)
        if hasattr(destination, "_metadata"):
            destination._metadata[prefix[:-1]] = local_metadata

        for hook in self._state_dict_pre_hooks.values():
            hook(self, prefix, keep_vars)
        self._save_to_state_dict(destination, prefix, keep_vars)
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
        for hook in self._state_dict_hooks.values():
            hook_result = hook(self, destination, prefix, local_metadata)
            if hook_result is not None:
                destination = hook_result
        return destination
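
As a quick illustration of what state_dict returns, the sketch below prints each entry's name and shape for a single nn.Linear layer (a minimal stand-in module, not part of the original recipe):

    import torch.nn as nn

    # state_dict() maps parameter/buffer names to tensors
    layer = nn.Linear(4, 2)
    for name, tensor in layer.state_dict().items():
        print(name, tuple(tensor.shape))
    # weight (2, 4)
    # bias (2,)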

Steps

  • Build the Dataset
  • Define and initialize the neural network

    • import torch
      import torch.nn as nn
      import torch.nn.functional as F
      import torch.optim as optim
      
      class Net(nn.Module):
          def __init__(self):
              super(Net, self).__init__()
              self.conv1 = nn.Conv2d(3, 6, 5)
              self.pool = nn.MaxPool2d(2, 2)
              self.conv2 = nn.Conv2d(6, 16, 5)
              self.fc1 = nn.Linear(16 * 5 * 5, 120)
              self.fc2 = nn.Linear(120, 84)
              self.fc3 = nn.Linear(84, 10)
      
          def forward(self, x):
              x = self.pool(F.relu(self.conv1(x)))
              x = self.pool(F.relu(self.conv2(x)))
              x = x.view(-1, 16 * 5 * 5)
              x = F.relu(self.fc1(x))
              x = F.relu(self.fc2(x))
              x = self.fc3(x)
              return x
      
      net = Net()
      print(net)
  • Initialize the optimizer

    • # The first argument is net.parameters(), which automatically walks every submodule and collects its parameters
      optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
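      
      # The optimizer keeps its own state_dict as well, holding the
      # hyperparameters (param_groups) and per-parameter state such as
      # momentum buffers
      print(optimizer.state_dict().keys())  # dict_keys(['state', 'param_groups'])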
  • Save and load the model via state_dict

    • PATH = "state_dict_model.pt"
      
      # Save: captures all of the model's parameters and buffers, but not its structure
      torch.save(net.state_dict(), PATH)
      
      # Load: first rebuild the network structure
      model = Net()
      # then load the weights into it
      model.load_state_dict(torch.load(PATH))
      # Inference only: the optimizer state was not saved
      model.eval()
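      
      # A minimal inference sketch: the 1x3x32x32 input shape is an assumption
      # chosen to match conv1's three input channels and the 16*5*5 flatten above
      with torch.no_grad():
          dummy = torch.randn(1, 3, 32, 32)
          prediction = model(dummy)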
  • Save and load the entire model

    • # Specify a path
      PATH = "entire_model.pt"
      
      # Save
      torch.save(net, PATH)
      
      # Load: no need to rebuild the model by hand, although the Net class
      # definition must still be importable, since pickle stores it by reference
      model = torch.load(PATH)
      model.eval()
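
One practical note for either approach: if the weights were saved on a GPU but are being loaded on a CPU-only machine, torch.load accepts a map_location argument. A minimal sketch, reusing the PATH from the state_dict example above:

    # Remap GPU-saved tensors onto the CPU at load time
    state = torch.load("state_dict_model.pt", map_location=torch.device('cpu'))
    model = Net()
    model.load_state_dict(state)
    model.eval()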

Saving and loading a general checkpoint in PyTorch

Saving and loading a general checkpoint model for ==inference or resuming training== can be helpful for picking up where you last left off. When saving a general checkpoint, you must save more than just the model’s state_dict. ==It is important to also save the optimizer’s state_dict==, as this contains buffers and parameters that are updated as the model trains. Other items that you may want to save are the epoch you left off on, the latest recorded training loss, external torch.nn.Embedding layers, and more, based on your own algorithm.

  • 4. Save the general checkpoint

    • # Additional information
      EPOCH = 5
      PATH = "model.pt"
      LOSS = 0.4
      
      torch.save({
                  'epoch': EPOCH,
                  'model_state_dict': net.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': LOSS,
                  }, PATH)
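      
      # Note: a common PyTorch convention is to save general checkpoints
      # with the .tar extension (e.g. "checkpoint.tar"), since the file
      # bundles several objects rather than a single state_dict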
  • 5. Load the general checkpoint

    • model = Net()
      optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
      
      checkpoint = torch.load(PATH)
      # model parameters
      model.load_state_dict(checkpoint['model_state_dict'])
      # optimizer parameters
      optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      epoch = checkpoint['epoch']
      loss = checkpoint['loss']
      
      model.eval()
      # - or -
      model.train()
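
If the checkpoint is used to resume training instead, the restored epoch can seed the loop. A minimal sketch, where num_epochs, train_loader, and criterion are assumed stand-ins for your own training setup:

    model.train()
    for ep in range(epoch + 1, num_epochs):          # num_epochs: assumed total
        for inputs, labels in train_loader:          # train_loader: assumed DataLoader
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)  # criterion: assumed loss function
            loss.backward()
            optimizer.step()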

Saving and loading multiple models in one file using PyTorch

  • 2. Define and initialize the neural network

    • class Net(nn.Module):
          def __init__(self):
              super(Net, self).__init__()
              self.conv1 = nn.Conv2d(3, 6, 5)
              self.pool = nn.MaxPool2d(2, 2)
              self.conv2 = nn.Conv2d(6, 16, 5)
              self.fc1 = nn.Linear(16 * 5 * 5, 120)
              self.fc2 = nn.Linear(120, 84)
              self.fc3 = nn.Linear(84, 10)
      
          def forward(self, x):
              x = self.pool(F.relu(self.conv1(x)))
              x = self.pool(F.relu(self.conv2(x)))
              x = x.view(-1, 16 * 5 * 5)
              x = F.relu(self.fc1(x))
              x = F.relu(self.fc2(x))
              x = self.fc3(x)
              return x
      
      netA = Net()
      netB = Net()
  • 3. Initialize the optimizer

    • optimizerA = optim.SGD(netA.parameters(), lr=0.001, momentum=0.9)
      optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.9)
  • 4. Save multiple models

    • # Specify a path to save to
      PATH = "model.pt"
      
      torch.save({
                  'modelA_state_dict': netA.state_dict(),
                  'modelB_state_dict': netB.state_dict(),
                  'optimizerA_state_dict': optimizerA.state_dict(),
                  'optimizerB_state_dict': optimizerB.state_dict(),
                  }, PATH)
  • 5. Load multiple models

    • modelA = Net()
      modelB = Net()
      optimizerA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
      optimizerB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)
      
      checkpoint = torch.load(PATH)
      modelA.load_state_dict(checkpoint['modelA_state_dict'])
      modelB.load_state_dict(checkpoint['modelB_state_dict'])
      optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
      optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
      
      modelA.eval()
      modelB.eval()
      # - or -
      modelA.train()
      modelB.train()
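
To sanity-check the round trip, the reloaded parameters can be compared against the originals; this is an optional verification sketch, not part of the recipe:

    # Confirm the reload reproduced netA's weights exactly
    for key in netA.state_dict():
        assert torch.equal(netA.state_dict()[key], modelA.state_dict()[key])
    print("modelA matches netA")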



This article was written by fmujie and is licensed under Creative Commons Attribution 3.0; it may be freely reproduced and quoted, provided the author is credited and the source is cited.
