
Solving the state_dict copy problem in PyTorch

Let's start with the conclusion:

model.state_dict() is a shallow copy: the returned parameter tensors still share storage with the model, so they keep changing as the network trains.

Use copy.deepcopy(model.state_dict()) instead, or serialize the parameters to disk promptly.

Here is the story behind this. A few days ago, while training a model with cross-validation, I saved the parameters of each fold's model via model.state_dict(), intending to pick the model with the best accuracy afterwards and load it back. The result: every saved copy turned out to be the last model. Each saved state_dict() object did have a different address, but on closer inspection the addresses of the individual parameter tensors inside every state_dict() were shared with the live model. Since I reset the model parameters in-place between folds, all of the "saved" copies followed those updates, which caused the problem above.
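A minimal sketch of the pitfall, using a one-layer nn.Linear model to stand in for a real network:

import copy
import torch
import torch.nn as nn

model = nn.Linear(2, 2)

shallow = model.state_dict()                  # tensors share storage with the model
snapshot = copy.deepcopy(model.state_dict())  # fully independent copy

with torch.no_grad():
  model.weight.add_(1.0)                      # in-place update, as happens during training

print(torch.equal(shallow['weight'], model.weight))   # True: the "saved" copy changed too
print(torch.equal(snapshot['weight'], model.weight))  # False: frozen at copy time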

Supplement: understanding state_dict in PyTorch

In PyTorch, state_dict is a Python dictionary object (an ordered dictionary in which each key is a layer's parameter name and each value is that layer's parameter tensor). It contains the model's learnable parameters (i.e. weights and biases, plus the parameters of BN layers). Optimizer objects (torch.optim) also have a state_dict, which contains information about the optimizer's state and the hyperparameters it uses.

After reading the output of the following code, this should become clear:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

The output is as follows:

Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]
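As the conclusion at the top suggests, serializing the parameters to disk in time is the safe alternative to holding state_dict() references in memory. A minimal sketch for the model and optimizer above (the file name checkpoint.pth is illustrative):

# Save a snapshot of both state_dicts to disk
torch.save({
  'model_state_dict': model.state_dict(),
  'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')

# Load the snapshot back later
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])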

I am a novice who has only just started with deep learning, and I hope more experienced readers will point out my shortcomings. This post is mainly my own notes!

Supplement: PyTorch reports "*** object has no attribute 'state_dict'" when saving a model

Define a class BaseNet and instantiate the class:

net = BaseNet()

Saving the net then fails with "object has no attribute 'state_dict'":

torch.save(net.state_dict(), models_dir)

The reason is that the class was not defined as a subclass of nn.Module, for example:

class BaseNet(object):
  def __init__(self):

Change the class definition to

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()
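For reference, a minimal runnable version of the fix (the nn.Linear layer and the file name are illustrative):

import torch
import torch.nn as nn

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()
    self.fc = nn.Linear(4, 2)  # illustrative layer

net = BaseNet()
# state_dict() is inherited from nn.Module, so saving now works
torch.save(net.state_dict(), 'basenet.pth')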

The above is my personal experience; I hope it gives you a useful reference. If there are mistakes or things I have not considered fully, please let me know.