SoFunction
Updated on 2025-03-02

PyTorch method to implement model pruning

Overview

In this article, I show how to implement model pruning in PyTorch. Pruning is a model optimization technique that reduces a model's size and computational cost, ideally with little loss of accuracy. I walk through the process step by step and give the PyTorch code for each step.

Overall process

The overall process of pruning a model in PyTorch consists of the following steps:

Step  Action
1     Load a pretrained model
2     Define the pruning algorithm
3     Perform the pruning operation
4     Retrain and fine-tune the model
5     Evaluate the performance of the pruned model

Steps detailed explanation

Step 1: Load the pretrained model

First, we need to load a pretrained model as our base model. Here, we take ResNet18 as an example.

import torch
import torchvision.models as models

# Load the pre-trained ResNet18 model
# (torchvision >= 0.13 prefers weights=models.ResNet18_Weights.DEFAULT)
model = models.resnet18(pretrained=True)

Step 2: Define the pruning algorithm

Next, we need to define a pruning algorithm. Here we take Global Magnitude Pruning as an example.

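import torch.nn as nn
from torch.nn.utils.prune import global_unstructured, L1Unstructured

# Define the pruning ratio
pruning_rate = 0.5

# Prune the fully connected layers of the model
def prune_model(model, pruning_rate):
    # Collect (module, parameter name) pairs for every nn.Linear layer
    parameters_to_prune = [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, nn.Linear)
    ]
    # Remove the smallest-magnitude weights across all collected layers
    global_unstructured(parameters_to_prune, pruning_method=L1Unstructured, amount=pruning_rate)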
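To make the idea behind global magnitude pruning concrete, here is a minimal sketch on a small stand-in network. The toy_model below is hypothetical and chosen only to keep the example fast; it is not the ResNet18 loaded above. The sketch pools all candidate weights, prunes half of them with prune.global_unstructured and L1Unstructured, and then checks that every pruned weight is no larger in magnitude than any weight that was kept.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Hypothetical stand-in model, used only to keep the example small.
toy_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
linear_layers = [m for m in toy_model.modules() if isinstance(m, nn.Linear)]

# Global pruning ranks the weights of all layers together.
total_weights = sum(m.weight.numel() for m in linear_layers)
num_to_prune = round(0.5 * total_weights)

prune.global_unstructured(
    [(m, "weight") for m in linear_layers],
    pruning_method=prune.L1Unstructured,
    amount=0.5,
)

# Gather the original magnitudes and the masks from every pruned layer.
orig = torch.cat([m.weight_orig.detach().abs().flatten() for m in linear_layers])
mask = torch.cat([m.weight_mask.flatten() for m in linear_layers])

num_pruned = int((mask == 0).sum())
largest_pruned = orig[mask == 0].max().item()
smallest_kept = orig[mask == 1].min().item()
print(num_pruned, num_to_prune, largest_pruned <= smallest_kept)
```

The pruned count is fixed for the model as a whole, so the individual layers are free to lose different fractions of their weights.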

Step 3: Perform the pruning operation

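Now we can run the pruning function and print the model. Note that the printed layer structure is unchanged, because unstructured pruning masks individual weights rather than removing layers.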

prune_model(model, pruning_rate)

# Check the pruned model structure
print(model)
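Printing the model does not reveal which weights were zeroed, so a sparsity report is often more informative. The helper below is a sketch, not part of PyTorch: sparsity_report and the demo network are hypothetical names introduced for illustration. It returns the fraction of zero weights for each layer that has a weight tensor.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def sparsity_report(model):
    """Return {layer name: fraction of zero weights} for layers that have a weight."""
    report = {}
    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        if isinstance(weight, torch.Tensor):
            report[name] = float((weight == 0).sum()) / weight.numel()
    return report

# Hypothetical demo network: prune 30% of the first layer only.
demo = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))
prune.l1_unstructured(demo[0], name="weight", amount=0.3)

report = sparsity_report(demo)
print(report)
```

The same helper can be called on the pruned ResNet18 from the steps above to see how much of each layer was removed.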

Step 4: Retrain and fine-tune the model

The pruned model usually needs to be retrained (fine-tuned) to recover the accuracy lost through pruning.

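from torch import nn, optim

# Define the loss function and optimizer (cross-entropy and SGD assumed as typical choices)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Retrain and fine-tune the model
# Omit training code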
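Since the training code is omitted, here is a minimal sketch of what a fine-tuning loop could look like. Everything in it is an assumption made for illustration: the small net, the random inputs and targets, the learning rate, and the number of steps stand in for a real model, dataset, and schedule. The point it demonstrates is that the pruning mask survives training, so pruned weights stay at zero.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Hypothetical small classifier and random data, standing in for a real model and dataset.
net = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))
prune.l1_unstructured(net[0], name="weight", amount=0.5)
mask_before = net[0].weight_mask.clone()

inputs = torch.randn(64, 20)
targets = torch.randint(0, 3, (64,))

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

net.train()
for step in range(5):
    optimizer.zero_grad()
    loss = criterion(net(inputs), targets)
    loss.backward()
    optimizer.step()

# The mask is a buffer, not a parameter, so the optimizer never updates it.
mask_unchanged = torch.equal(net[0].weight_mask, mask_before)

# The effective weight is always weight_orig * weight_mask.
with torch.no_grad():
    effective = net[0].weight_orig * net[0].weight_mask
pruned_still_zero = bool((effective[mask_before == 0] == 0).all())
print(mask_unchanged, pruned_still_zero)
```

In a real run the loop would iterate over mini-batches from a DataLoader for several epochs.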

Step 5: Evaluate the model performance after pruning

Finally, we evaluate the pruned model and compare its performance with that of the unpruned model.

# Evaluate the pruned model
# Omit the evaluation code
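The evaluation code is omitted as well, so the following is only a sketch of one way to compare a model before and after pruning. The accuracy helper, the baseline model, and the random validation data are hypothetical; with random data the accuracy values themselves are meaningless, and only the comparison pattern matters.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

@torch.no_grad()
def accuracy(model, inputs, targets):
    """Fraction of samples whose arg-max prediction matches the target."""
    model.eval()
    predictions = model(inputs).argmax(dim=1)
    return (predictions == targets).float().mean().item()

# Hypothetical model and random data; substitute a trained model and a real validation set.
baseline = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))
pruned = copy.deepcopy(baseline)  # copy first so the baseline stays unpruned
prune.l1_unstructured(pruned[0], name="weight", amount=0.5)

val_inputs = torch.randn(128, 20)
val_targets = torch.randint(0, 3, (128,))

acc_before = accuracy(baseline, val_inputs, val_targets)
acc_after = accuracy(pruned, val_inputs, val_targets)
print(f"accuracy before pruning: {acc_before:.3f}, after pruning: {acc_after:.3f}")
```

Copying the model before pruning keeps an untouched baseline available for the comparison.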

Supplement: There are three ways to prune in PyTorch:

  • Local pruning
  • Global pruning
  • Custom pruning

Local pruning

As a local pruning experiment, we prune the weights of the model's first convolutional layer.

import torch.nn.utils.prune as prune

model_1 = LeNet()
module = model_1.conv1
# Before pruning
print(list(module.named_parameters()))
print(list(module.named_buffers()))
# Randomly prune 30% of the entries in conv1's weight
prune.random_unstructured(module, name="weight", amount=0.3)
# After pruning
print(list(module.named_parameters()))
print(list(module.named_buffers()))

Running results

## Before pruning
[('weight', Parameter containing:
tensor([[[[ 0.1729, -0.0109, -0.1399],
          [ 0.1019,  0.1883,  0.0054],
          [-0.0790, -0.1790, -0.0792]]],
        
        ...

        [[[ 0.2465,  0.2114,  0.3208],
          [-0.2067, -0.2097, -0.0431],
          [ 0.3005, -0.2022,  0.1341]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1437,  0.0605,  0.1427, -0.3111, -0.2476,  0.1901],
       requires_grad=True))]
[]

## After pruning
[('bias', Parameter containing:
tensor([-0.1437,  0.0605,  0.1427, -0.3111, -0.2476,  0.1901],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1729, -0.0109, -0.1399],
          [ 0.1019,  0.1883,  0.0054],
          [-0.0790, -0.1790, -0.0792]]],

        ...

        [[[ 0.2465,  0.2114,  0.3208],
          [-0.2067, -0.2097, -0.0431],
          [ 0.3005, -0.2022,  0.1341]]]], requires_grad=True))]

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 1., 0.],
          [0., 1., 1.],
          [1., 0., 1.]]],


        [[[0., 1., 1.],
          [1., 0., 1.],
          [1., 0., 1.]]],


        [[[1., 1., 1.],
          [1., 0., 1.],
          [0., 1., 0.]]],


        [[[0., 0., 1.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 1., 1.],
          [0., 1., 0.],
          [1., 1., 1.]]]]))]

After pruning, the original weight entry disappears from module.named_parameters() and is replaced by weight_orig, which stores the unpruned values. module.named_buffers(), which printed an empty list before pruning, now contains weight_mask, a 0/1 mask tensor with the same shape as the weight. The module still exposes a weight attribute, but it is now computed as the elementwise product of weight_orig and weight_mask.
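The relationship between weight, weight_orig, and weight_mask can be checked directly. The sketch below uses a standalone nn.Conv2d with the same shape as conv1 (an assumption made so that the example does not depend on the LeNet class). It also shows prune.remove, which makes the pruning permanent by folding the mask into a regular weight parameter.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(1, 6, kernel_size=3)
prune.random_unstructured(conv, name="weight", amount=0.3)

# weight is now an attribute derived from the two stored tensors.
same = torch.equal(conv.weight, conv.weight_orig * conv.weight_mask)
param_names = sorted(name for name, _ in conv.named_parameters())
buffer_names = [name for name, _ in conv.named_buffers()]
print(same, param_names, buffer_names)

# prune.remove drops weight_orig and weight_mask and keeps the masked weight.
prune.remove(conv, "weight")
names_after_remove = sorted(name for name, _ in conv.named_parameters())
buffers_after_remove = [name for name, _ in conv.named_buffers()]
zeros_after_remove = int((conv.weight == 0).sum())
print(names_after_remove, buffers_after_remove, zeros_after_remove)
```

After prune.remove the zeros remain in the weight, but the mask is gone, so further training can move those entries away from zero again.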

Global pruning

Local pruning operates on one module at a time. A broader strategy is global pruning, which removes, for example, 20% of the weights across the whole network rather than 20% of the weights in each layer. As a result, different layers end up with different pruned fractions.

import torch
import torch.nn.utils.prune as prune

# Pick a device; falls back to the CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_2 = LeNet().to(device=device)

# First print the state dictionary keys of the freshly initialized model
print(model_2.state_dict().keys())

# Build the set of (module, parameter name) pairs that take part in pruning
parameters_to_prune = (
            (model_2.conv1, 'weight'),
            (model_2.conv2, 'weight'),
            (model_2.fc1, 'weight'),
            (model_2.fc2, 'weight'),
            (model_2.fc3, 'weight'))

# Call prune.global_unstructured to prune 20% of the weights across the whole model
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

# Finally print the state dictionary keys of the pruned model
print(model_2.state_dict().keys())

Output result

odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'fc1.bias', 'fc1.weight_orig', 'fc1.weight_mask', 'fc2.bias', 'fc2.weight_orig', 'fc2.weight_mask', 'fc3.bias', 'fc3.weight_orig', 'fc3.weight_mask'])

With the global pruning strategy (here amount=0.2), 20% of the selected parameters are pruned in total. How many are removed from each individual layer depends on the distribution of weight magnitudes in the model.
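The uneven per-layer effect is easy to see in a small experiment. The sketch below is built on assumptions chosen for clarity: two hypothetical layers of equal size, with the second one deliberately scaled to have much smaller weights. Global pruning at 20% then removes far more from the small-weight layer, while the overall pruned fraction is still 20%.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Two hypothetical layers; the second is given much smaller weights on purpose.
layer_a = nn.Linear(50, 50)
layer_b = nn.Linear(50, 50)
with torch.no_grad():
    layer_b.weight.mul_(0.01)

prune.global_unstructured(
    [(layer_a, "weight"), (layer_b, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

def zero_fraction(module):
    """Fraction of entries in module.weight that are exactly zero."""
    return float((module.weight == 0).sum()) / module.weight.numel()

sparsity_a = zero_fraction(layer_a)
sparsity_b = zero_fraction(layer_b)
overall = (sparsity_a + sparsity_b) / 2  # valid because both layers have the same size
print(f"layer_a: {sparsity_a:.1%}, layer_b: {sparsity_b:.1%}, overall: {overall:.1%}")
```

Local pruning with the same amount would instead remove exactly 20% from each layer, regardless of weight scale.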

Custom pruning

Custom pruning means subclassing prune.BasePruningMethod and implementing compute_mask with your own pruning logic, for example pruning every other entry of a parameter tensor.

import torch.nn.utils.prune as prune

class my_pruning_method(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        # Zero out every other entry of the flattened mask
        mask.view(-1)[::2] = 0
        return mask

def my_unstructured_pruning(module, name):
    my_pruning_method.apply(module, name)
    return module

model_3 = LeNet()
print(model_3)

View network structure before pruning

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
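The article never shows the LeNet class itself. A definition consistent with the printed structure could look like the sketch below. The layer shapes come from the printout; the forward pass and the 28x28 single-channel input size are assumptions, chosen so that the flattened feature size equals the 400 inputs expected by fc1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Layer shapes match the printed structure.
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Assumed forward pass: 28 -> 26 -> 13 -> 11 -> 5 spatially.
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

lenet = LeNet()
output = lenet(torch.randn(2, 1, 28, 28))
print(output.shape)
```

With a class like this in scope, the local, global, and custom pruning snippets in this section can be run as written.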

Use the custom pruning method to prune the bias of the fc3 module

my_unstructured_pruning(model_3.fc3, name="bias")
print(model_3.fc3.bias_mask)

Output result

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])

Every other entry of the mask is zero, which matches the logic implemented in compute_mask.
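Custom methods can also take parameters. The sketch below defines a hypothetical ThresholdPruning class (not part of PyTorch) that zeroes every entry whose magnitude is below a given threshold. It relies on BasePruningMethod.apply forwarding extra keyword arguments to the subclass constructor; the layer and its hand-set weights are made up for the demonstration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class ThresholdPruning(prune.BasePruningMethod):
    """Hypothetical custom method: zero every entry whose magnitude is below a threshold."""
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask[t.abs() < self.threshold] = 0
        return mask

# Hand-set weights so the effect of the threshold is easy to read.
layer = nn.Linear(4, 1)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.05, -0.5, 0.2, -0.01]]))

# Extra keyword arguments to apply() are forwarded to the constructor.
ThresholdPruning.apply(layer, "weight", threshold=0.1)
print(layer.weight_mask)
```

Starting from default_mask.clone() keeps any entries that an earlier pruning step has already masked out.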

Summary

With the steps and code examples above, you should be able to implement model pruning in PyTorch. Pruning is an effective optimization technique that can help you build smaller, more efficient deep learning models while keeping accuracy close to that of the original model.
