Solving the state_dict shallow-copy problem in PyTorch

  • 2021-09-16 07:19:25
  • OfStack

Let's start with the conclusion

model.state_dict() returns a shallow copy: the tensors in the returned dictionary share storage with the model's parameters, so their values keep changing as the network trains.

To keep a fixed snapshot, use copy.deepcopy(model.state_dict()), or serialize the parameters to disk right away.
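As a minimal sketch of this conclusion, the snippet below compares a plain state_dict() with a deep copy and a serialized copy. The tiny nn.Linear model and the in-memory buffer (standing in for a file on disk) are assumptions made for illustration, not part of the original post.

```python
import copy
import io

import torch
import torch.nn as nn

model = nn.Linear(2, 1)  # tiny stand-in model, not from the original post

shallow = model.state_dict()                  # tensors share storage with the model
snapshot = copy.deepcopy(model.state_dict())  # independent copy of every tensor

# Serializing also freezes the values; an in-memory buffer stands in for a file on disk.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Simulate a training update by changing the weights in place.
with torch.no_grad():
    model.weight.add_(1.0)

buffer.seek(0)
restored = torch.load(buffer)

shallow_follows_model = torch.equal(shallow["weight"], model.weight)
snapshot_unchanged = torch.equal(snapshot["weight"], restored["weight"])
restored_differs = not torch.equal(restored["weight"], model.weight)
print(shallow_follows_model, snapshot_unchanged, restored_differs)
```

The plain state_dict() follows the in-place update, while the deep copy and the serialized copy both keep the values from before the update.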

Now the full story. A few days ago, while running cross-validation training for a model, I saved the parameters of each fold's model with model.state_dict(). Afterwards I loaded the model with the best accuracy, but the result was always the last model. Judging by address, every saved state_dict() was a different object, but the parameter tensors inside them all shared the same addresses. Because I was resetting the model parameters in place between folds, every saved dictionary ended up holding the weights of the last fold.
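The situation described above can be reproduced with a short sketch. The three-fold loop, the nn.Linear stand-in model, and the fill_() call that imitates an in-place reset are all hypothetical; the real training code is not shown in the post.

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(2, 1)  # hypothetical stand-in for the cross-validated model

# Two calls return two different dictionary objects...
first = model.state_dict()
second = model.state_dict()
different_dicts = first is not second
# ...but the tensors inside point at the same memory as the model's parameter.
same_storage = (
    first["weight"].data_ptr()
    == second["weight"].data_ptr()
    == model.weight.data_ptr()
)

saved_shallow = []  # what the post did: one state_dict() per fold
saved_deep = []     # the fix: one deep copy per fold
for fold in range(3):
    with torch.no_grad():
        model.weight.fill_(float(fold))  # imitates an in-place reset plus training
    saved_shallow.append(model.state_dict())
    saved_deep.append(copy.deepcopy(model.state_dict()))

shallow_values = [s["weight"][0, 0].item() for s in saved_shallow]
deep_values = [s["weight"][0, 0].item() for s in saved_deep]
print(shallow_values)  # every entry shows the last fold: [2.0, 2.0, 2.0]
print(deep_values)     # each entry keeps its own fold: [0.0, 1.0, 2.0]
```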

Supplement: Understanding state_dict in PyTorch

In PyTorch, a model's state_dict is a Python dictionary object that contains the model's learnable parameters (i.e. weights and biases, as well as the parameters of BN layers). In this ordered dictionary, each key is a layer's parameter name and each value is that parameter's tensor. The optimizer object (torch.optim) also has a state_dict, which contains information about the optimizer state and the hyperparameters used.

The output of the following code makes this clear:


import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in forward()
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

The output is as follows:


Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]
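The empty state entry in this output reflects an optimizer that has not taken a step yet. As a small sketch (the nn.Linear model and random input are assumptions chosen to keep the example short), the state fills in once SGD with momentum performs an update:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical small model with a weight and a bias
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

state_before = optimizer.state_dict()["state"]  # empty before any step

# One forward/backward pass followed by one optimizer step.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

state_after = optimizer.state_dict()["state"]  # one entry per updated parameter
print(len(state_before), len(state_after))
for index, entry in state_after.items():
    print(index, list(entry.keys()))
```

With momentum enabled, each parameter's entry holds a momentum_buffer tensor, which is why the optimizer's state_dict must be saved as well if training is to be resumed exactly.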

I am a beginner who has only just started with deep learning, and I hope more experienced readers will point out my mistakes. This post is only my own notes.

Supplement: PyTorch error when saving a model: ***object has no attribute 'state_dict'

A class BaseNet is defined and instantiated:


net=BaseNet()

Saving net raises the error object has no attribute 'state_dict':


torch.save(net.state_dict(), models_dir)

The reason is that the class was defined without inheriting from nn.Module, for example:


class BaseNet(object):
  def __init__(self):

Change the class definition to:


class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()
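To check the difference, here is a minimal sketch. PlainNet, the single fc layer, and the in-memory buffer used in place of models_dir are hypothetical names introduced only for illustration.

```python
import io

import torch
import torch.nn as nn


class PlainNet(object):
    """Hypothetical class that does NOT inherit from nn.Module."""

    def __init__(self):
        self.fc = nn.Linear(4, 2)


class BaseNet(nn.Module):
    """Same layer, but registered through nn.Module."""

    def __init__(self):
        super(BaseNet, self).__init__()
        self.fc = nn.Linear(4, 2)  # hypothetical layer for illustration


# A plain Python object has no state_dict method, hence the AttributeError.
plain_has_state_dict = hasattr(PlainNet(), "state_dict")

net = BaseNet()
keys = list(net.state_dict().keys())

# Saving now works; the buffer stands in for the models_dir path.
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
saved_bytes = buffer.getbuffer().nbytes
print(plain_has_state_dict, keys)
```

Inheriting from nn.Module and calling the parent constructor is what registers the layers, so their parameters appear in state_dict() under names such as fc.weight and fc.bias.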
