SoFunction
Updated on 2025-03-01

How to freeze BN layer parameters in PyTorch

Background:

In a PyTorch model, I wanted to freeze the main-branch parameters and train only the sub-branches. However, the same test data passed through the main branch produced different outputs in different epochs.

Reason:

The running_mean and running_var buffers of the BN layers in the main branch are not frozen.

Solution:

Put the BN layers that need to be frozen into eval mode.

Example of the problem:

Environment: torch: 1.7.0

# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

def print_parameter_grad_info(net):
    print('-------parameters requires grad info--------')
    for name, p in net.named_parameters():
        print(f'{name}:\t{p.requires_grad}')

def print_net_state_dict(net):
    for key, v in net.state_dict().items():
        print(f'{key}')

if __name__ == "__main__":
    net = Net()

    print_parameter_grad_info(net)
    net.requires_grad_(False)
    print_parameter_grad_info(net)

    torch.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, assuming each epoch is iterated only once
        net.train()
        pre = net(train_data)
        # Calculate loss and parameter updates, etc.
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

Running results:

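-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False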
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])

As the output shows:

net.requires_grad_(False) has set every parameter in the network so that it receives no gradient updates, yet the same test_data gives different results after the forward pass in different epochs.

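Calling print_net_state_dict shows that each BN layer also stores running_mean and running_var (along with num_batches_tracked). These are buffers, not parameters, so they do not appear in the named_parameters() list printed above and are not affected by requires_grad_(False). For bn1, for example:
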
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
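To see the parameter/buffer distinction on a single layer, the sketch below (a standalone nn.BatchNorm2d, not the article's Net) lists what named_parameters() and named_buffers() return. Only the first list is what requires_grad_(False) and the optimizer act on.

```python
import torch.nn as nn

bn = nn.BatchNorm2d(6)

# Learnable parameters: these are what requires_grad_(False) freezes
param_names = [name for name, _ in bn.named_parameters()]
# Buffers: saved in state_dict, but never updated by autograd or the optimizer
buffer_names = [name for name, _ in bn.named_buffers()]

print(param_names)   # ['weight', 'bias']
print(buffer_names)  # ['running_mean', 'running_var', 'num_batches_tracked']
```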

However, during the forward pass of the training phase, these two buffers are updated. As a result, the same test data produces different outputs even though the whole network is frozen.

From the PyTorch BatchNorm2d documentation: "Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1."
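The effect can be reproduced on one layer. This is a minimal sketch assuming a standalone nn.BatchNorm2d and random input (not the article's Net): the layer is frozen with requires_grad_(False), and the check is whether a forward pass changes running_mean in training mode versus eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
bn.requires_grad_(False)  # freezes weight and bias only, not the running buffers

x = torch.rand(4, 3, 8, 8)

before = bn.running_mean.clone()
bn.train()
bn(x)  # training-mode forward updates running_mean / running_var
changed_in_train = not torch.equal(before, bn.running_mean)

before = bn.running_mean.clone()
bn.eval()
bn(x)  # eval-mode forward only reads the running statistics
changed_in_eval = not torch.equal(before, bn.running_mean)

print(changed_in_train, changed_in_eval)  # True False
```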

Therefore, during the training phase, explicitly put the BN layers into eval mode:

if __name__ == "__main__":
    net = Net()
    net.requires_grad_(False)

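    torch.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, assuming each epoch is iterated only once
        net.train()
        # net.train() also switches the BN layers to training mode,
        # so put the frozen BN layers back into eval mode afterwards
        net.bn1.eval()
        net.bn2.eval()
        pre = net(train_data)
        # Calculate loss and parameter updates, etc.
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)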

The results are now identical across epochs:

epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
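Calling eval() on each BN layer by name does not scale to larger networks. Below is a minimal sketch of a helper that walks all submodules and puts every BatchNorm layer into eval mode; the name set_bn_eval is hypothetical (it is not part of the PyTorch API), and the small Sequential model is only for illustration. It must be called again after every model.train(), because train() resets all submodules to training mode.

```python
import torch
import torch.nn as nn

# Hypothetical helper, not a PyTorch API function.
def set_bn_eval(module: nn.Module) -> None:
    """Put every BatchNorm1d/2d/3d layer inside `module` into eval mode."""
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()

net = nn.Sequential(nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
net.requires_grad_(False)

net.train()       # switches every submodule, BN included, to training mode
set_bn_eval(net)  # so the BN layers are re-frozen right after train()

before = net[1].running_mean.clone()
net(torch.rand(2, 1, 8, 8))
print(net.training, net[1].training)             # True False
print(torch.equal(before, net[1].running_mean))  # True
```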

Supplement: BN layer parameters in PyTorch, explained, with application tips

Detailed explanation of BN layer parameters

Generally speaking, models in PyTorch inherit from nn.Module, and every module has a training attribute that indicates whether it is in training mode. This mode determines how certain layers behave, for example the BN layer (at test time, BN normalizes with the mean and variance accumulated over all batches during training) and the Dropout layer (at test time, all neurons are active). Usually model.train() puts the model into training mode and model.eval() puts it into evaluation (test) mode.
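A small sketch of this behavior, assuming an illustrative Sequential model with a BN and a Dropout layer: train() and eval() set the training flag recursively on every submodule, and in eval mode Dropout becomes the identity, so repeated forward passes agree.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(p=0.5))

model.train()
train_flags = [m.training for m in model.modules()]  # every module: True

model.eval()
eval_flags = [m.training for m in model.modules()]   # every module: False

# In eval mode Dropout is a no-op and BN uses fixed running stats,
# so two forward passes on the same input give identical results
x = torch.rand(3, 4)
same_output = torch.equal(model(x), model(x))
print(all(train_flags), any(eval_flags), same_output)  # True False True
```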

In addition, BN's API has several parameters that need attention: affine specifies whether a learnable affine transform is applied, and track_running_stats specifies whether running statistics are tracked across batches. Three settings are easy to get wrong: training, affine, and track_running_stats.

affine specifies whether the final scale-and-shift step of the BN transform, $y = \gamma \hat{x} + \beta$, is applied. If affine=False, then $\gamma = 1, \beta = 0$ and they cannot be learned or updated. It is usually set to affine=True, in which case $\gamma$ and $\beta$ are learnable parameters (the layer's weight and bias).
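A minimal check of the affine=False case, assuming random input: the layer has no weight or bias, and in training mode its output equals plain normalization with the batch mean and biased batch variance, i.e. $\gamma = 1, \beta = 0$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)

bn = nn.BatchNorm2d(3, affine=False)
print(bn.weight, bn.bias)          # None None: no learnable gamma/beta
print(len(list(bn.parameters())))  # 0

bn.train()
y = bn(x)

# Manual normalization per channel with batch statistics (biased variance)
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
y_manual = (x - mean) / torch.sqrt(var + bn.eps)
print(torch.allclose(y, y_manual, atol=1e-5))  # True
```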

training and track_running_stats: track_running_stats=True means the layer tracks statistics over the entire training process to obtain a running mean and variance, instead of relying only on the statistics of the current input batch (each new batch updates the estimates accumulated from previous batches through the momentum parameter, i.e., an exponential moving average, EMA). Conversely, with track_running_stats=False the layer only computes the mean and variance of the current input batch. In the inference stage, if track_running_stats=False and the batch size is small, the batch statistics can deviate considerably from the global statistics, which may lead to poor results.
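The sketch below contrasts the two settings in eval mode, assuming a toy BatchNorm1d with two features and an input batch whose statistics are far from the initial running statistics. With tracking, eval uses the stored running_mean/running_var (still at their initial 0 and 1 here, so the output is almost the input); without tracking, no running buffers exist and eval normalizes with the current batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn_tracked = nn.BatchNorm1d(2)                               # track_running_stats=True (default)
bn_untracked = nn.BatchNorm1d(2, track_running_stats=False)

print(bn_tracked.running_mean)    # tensor([0., 0.])
print(bn_untracked.running_mean)  # None: no running statistics are kept

x = torch.randn(16, 2) * 3 + 5    # batch mean ~5, std ~3

bn_tracked.eval()
bn_untracked.eval()
out_tracked = bn_tracked(x)       # uses running stats (0, 1): output ~= input
out_untracked = bn_untracked(x)   # uses this batch's stats: output mean ~= 0
print(out_untracked.mean(dim=0))
```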

Application tips:

PyTorch code usually calls optimizer.zero_grad() to clear the gradients accumulated by previous batches. Because PyTorch accumulates gradients (in older versions, the gradients computed for a Variable) instead of overwriting them, they must be cleared for every batch. The standard approach is as follows:

Open question: what exactly does the non_blocking parameter do, and how does it fit into the overall structure of a PyTorch training loop?

Code (1)

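import numpy as np
import torch

for index, (data, target) in enumerate(dataloader):
    data = data.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    output = model(data)
    loss = criterion(output, target)

    # Clear the gradient
    optimizer.zero_grad()
    # Compute gradients, then update the weights
    loss.backward()
    optimizer.step()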
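Why clearing is needed can be seen on a single tensor, assuming a toy weight vector w: each backward() call adds into .grad rather than replacing it. optimizer.zero_grad() clears .grad for each parameter in a similar way (zeroing it, or setting it to None in newer PyTorch versions).

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([3.0, 4.0])

(w * x).sum().backward()
g1 = w.grad.clone()
print(g1)      # tensor([3., 4.])

(w * x).sum().backward()  # second backward without clearing: gradients add up
g2 = w.grad.clone()
print(g2)      # tensor([6., 8.])

w.grad.zero_()  # manual equivalent of clearing the gradient
print(w.grad)  # tensor([0., 0.])
```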

Here, to imitate a larger mini-batch, we do not clear the gradients after every batch. Instead, gradients are accumulated over a certain number of batches, and only then are the weights updated and the gradients cleared:

for index, (data, target) in enumerate(dataloader):
    # If it is not a Tensor, you usually need to use torch.from_numpy()
    data = data.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    if (index + 1) % accumulation == 0:
        # Update weights with accumulated gradients
        optimizer.step()
        # Clear the gradient
        optimizer.zero_grad()
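For experimenting without a GPU, here is a CPU-only sketch of the same accumulation loop. The toy dataset, the nn.Linear model, MSELoss, SGD, and accumulation = 2 are all illustrative assumptions. Using (index + 1) makes the first update happen after `accumulation` batches rather than after the very first one.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
accumulation = 2  # number of batches whose gradients are summed before one update

dataset = TensorDataset(torch.randn(24, 4), torch.randn(24, 1))
dataloader = DataLoader(dataset, batch_size=4)  # 6 batches
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

steps = 0
optimizer.zero_grad()
for index, (data, target) in enumerate(dataloader):
    output = model(data)
    loss = criterion(output, target)
    loss.backward()                      # gradients keep accumulating in .grad
    if (index + 1) % accumulation == 0:  # one update every `accumulation` batches
        optimizer.step()
        optimizer.zero_grad()
        steps += 1

print(steps)  # 3
```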

Although the accumulated gradient here is equivalent to summing the gradients of `accumulation` batches, accumulation has almost no effect on BN, because in the forward pass BN still only computes the mean and variance of a single batch. In this case you can tune the momentum parameter of PyTorch's BN layers (default 0.1), which controls the exponential moving average used to update the running statistics:

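$$x_{\text{running}}^{\text{new}} = (1 - \text{momentum}) \times x_{\text{running}} + \text{momentum} \times x_{\text{observed}}$$

where $x_{\text{running}}$ is the current running estimate and $x_{\text{observed}}$ is the statistic computed from the current batch.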
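A minimal sketch that checks this formula against the layer's own update, assuming a standalone one-channel nn.BatchNorm1d and random input. Note that PyTorch stores the unbiased batch variance in running_var, while the batch itself is normalized with the biased variance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
momentum = 0.1
bn = nn.BatchNorm1d(1, momentum=momentum)
x = torch.randn(8, 1)

old_mean = bn.running_mean.clone()  # initialized to 0
old_var = bn.running_var.clone()    # initialized to 1

bn.train()
bn(x)  # one training-mode forward pass updates the running statistics

batch_mean = x.mean(dim=0)
batch_var = x.var(dim=0)  # unbiased (divides by N - 1), as stored by BN
expected_mean = (1 - momentum) * old_mean + momentum * batch_mean
expected_var = (1 - momentum) * old_var + momentum * batch_var
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-6),
      torch.allclose(bn.running_var, expected_var, atol=1e-6))  # True True
```

According to the PyTorch documentation, momentum can also be set to None, in which case a cumulative (simple) moving average is used instead of the EMA.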

The above is my personal experience. I hope it serves as a useful reference, and I hope you will continue to support this site.