Conclusion first:
model.state_dict() is only a shallow copy. The returned dictionary is new, but the tensors inside it share memory with the model's parameters, so they keep changing as the network trains.
Use deepcopy(model.state_dict()) instead, or serialize the parameters to disk promptly (for example with torch.save).
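The difference between the three options can be checked directly. The sketch below is illustrative: it uses a tiny nn.Linear in place of a real network, an in-memory io.BytesIO buffer as a stand-in for a file on disk, and reset_parameters() as the in-place reset.

```python
import copy
import io

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
original = model.weight.detach().clone()  # reference copy for comparison

shallow = model.state_dict()              # tensors alias the model's parameters
deep = copy.deepcopy(model.state_dict())  # independent tensors

buffer = io.BytesIO()                     # stands in for a file on disk
torch.save(model.state_dict(), buffer)

# Re-initialize the parameters in place, as one might between training runs
model.reset_parameters()

buffer.seek(0)
from_disk = torch.load(buffer)

print(torch.equal(shallow["weight"], model.weight))  # True: it followed the reset
print(torch.equal(deep["weight"], original))         # True: snapshot preserved
print(torch.equal(from_disk["weight"], original))    # True: snapshot preserved
```

Only the bare state_dict() follows the reset; both the deep copy and the serialized copy keep the values from the moment they were taken.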
Here is what happened. A few days ago, while doing cross-validation training, I saved each fold's model parameters with model.state_dict(), planning to pick the fold with the best accuracy and load it back afterwards. Every time, however, the loaded model turned out to be the last one. Checking the addresses, each saved state_dict() was a different dictionary object, but the parameter tensors inside them shared the same memory as the model's own parameters. Because I reset the model parameters in place between folds, every saved copy was overwritten along with the model, which caused the problem.
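A minimal sketch of the fold-selection pattern with the fix applied. The fold accuracies are hypothetical placeholders for real evaluation results, the training step is omitted, and nn.Linear.reset_parameters() stands in for whatever in-place reset the real code used.

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Hypothetical per-fold validation accuracies; real code would compute these
fold_accuracies = [0.71, 0.83, 0.78]

best_acc = -1.0
best_state = None
best_weight = None  # kept only to check the result below

for acc in fold_accuracies:
    model.reset_parameters()  # in-place re-initialization at the start of each fold
    # ... train and evaluate this fold here ...
    if acc > best_acc:
        best_acc = acc
        # deepcopy freezes this fold's parameters; a bare state_dict() would keep
        # tracking the model and end up holding the last fold's values
        best_state = copy.deepcopy(model.state_dict())
        best_weight = model.weight.detach().clone()

# Restore the best fold's parameters
model.load_state_dict(best_state)
print(best_acc, torch.equal(model.weight, best_weight))  # 0.83 True
```

Replacing the deepcopy line with best_state = model.state_dict() reproduces the bug: after the third fold's reset, best_state would hold the third fold's parameters.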
Supplement: understanding state_dict in PyTorch
In PyTorch, state_dict is a Python dictionary object (an ordered dictionary whose keys are the parameter names of each layer and whose values are the corresponding tensors). It contains the model's learnable parameters (weights and biases) as well as registered buffers such as the running statistics of BatchNorm layers. Optimizer objects (torch.optim) also have a state_dict, which holds the optimizer's internal state and the hyperparameters it uses.
The output of the following code should make this concrete:
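To see that a state_dict also carries non-learnable buffers, here is a small illustrative check; the two-layer nn.Sequential is just an assumed toy model.

```python
import torch.nn as nn
import torch.optim as optim

net = nn.Sequential(nn.Conv2d(3, 6, 5), nn.BatchNorm2d(6))
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Learnable parameters only
print([name for name, _ in net.named_parameters()])
# ['0.weight', '0.bias', '1.weight', '1.bias']

# state_dict adds the BatchNorm buffers (running statistics and a step counter)
print(list(net.state_dict().keys()))
# ['0.weight', '0.bias', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked']

# The optimizer's state_dict has two top-level entries
print(list(optimizer.state_dict().keys()))  # ['state', 'param_groups']
```

The optimizer's "state" entry stays empty until the first optimizer.step(); with momentum, that step fills it with per-parameter momentum buffers.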
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict: parameter name and tensor shape
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
The output is as follows:
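Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])
Optimizer's state_dict:
state            {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]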
I am a beginner who has only recently started with deep learning, and I hope more experienced readers will point out my mistakes. This post is mainly my own notes!
Supplement: PyTorch error when saving a model: '***' object has no attribute 'state_dict'
I defined a class BaseNet and instantiated it:
net = BaseNet()
Saving it then failed with "object has no attribute 'state_dict'":
torch.save(net.state_dict(), models_dir)
The cause is that the class does not inherit from nn.Module, for example:
class BaseNet(object):
    def __init__(self):
        ...  # layers defined here
Change the class definition so that it inherits from nn.Module and calls the parent constructor:
class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()
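A small illustrative comparison of the two definitions. PlainNet and ModuleNet are hypothetical names used only so both versions can coexist in one script; the single nn.Linear layer is likewise an assumption.

```python
import torch.nn as nn

class PlainNet(object):  # does not inherit from nn.Module
    def __init__(self):
        self.fc = nn.Linear(4, 2)

class ModuleNet(nn.Module):  # inherits from nn.Module
    def __init__(self):
        super().__init__()  # must run before any submodule is assigned
        self.fc = nn.Linear(4, 2)

print(hasattr(PlainNet(), "state_dict"))      # False
print(list(ModuleNet().state_dict().keys()))  # ['fc.weight', 'fc.bias']
```

The order inside __init__ also matters: in an nn.Module subclass, assigning a submodule before super().__init__() runs raises an AttributeError ("cannot assign module before Module.__init__() call").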