How can I prune EfficientNet parameters via PyTorch?

I am training an EfficientNet model in PyTorch and, to reduce overfitting, I want to prune some of the parameters.

My model is implemented as follows:


import torchvision.models as models
import torch.nn as nn

model = models.efficientnet_b0(pretrained=True)

for params in model.parameters():
    params.requires_grad = True

model.classifier[1] = nn.Linear(in_features=1280, out_features=num_classes)

This tutorial for pruning a PyTorch model suggests isolating a module and then pruning it as follows:

import torch.nn.utils.prune as prune

module = model.conv1
prune.random_unstructured(module, name="weight", amount=0.3)

However, there is no module named conv1 in the EfficientNet model. I then listed the model's modules and got the following:

  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)

...

However, I do not know which of these are the module names, nor do I know how to isolate the module. How can I do so?

Am I on the right path for pruning my original model? Is there a different approach I should take? Thank you for the clarification and help.

There are 2 answers below.

Answer from Laassairi Abdellah:

To start with, let's load the model:

import torchvision.models as models
import torch.nn as nn
from torch.nn.utils import prune 

model = models.efficientnet_b0(pretrained=True)
num_classes = 10
for params in model.parameters():
    params.requires_grad = True

model.classifier[1] = nn.Linear(in_features=1280, out_features=num_classes)

A clear blueprint of the model:

Printing the model directly isn't a very clear way to get an overview of its different components. I personally prefer using torchinfo, e.g.:

from torchinfo import summary
summary(model, input_size=(1, 3, 28, 28))

This will give you the following output:

=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
EfficientNet                                            [1, 10]                   --
├─Sequential: 1-1                                       [1, 1280, 1, 1]           --
│    └─Conv2dNormActivation: 2-1                        [1, 32, 14, 14]           --
│    │    └─Conv2d: 3-1                                 [1, 32, 14, 14]           864
│    │    └─BatchNorm2d: 3-2                            [1, 32, 14, 14]           64
│    │    └─SiLU: 3-3                                   [1, 32, 14, 14]           --
│    └─Sequential: 2-2                                  [1, 16, 14, 14]           --
│    │    └─MBConv: 3-4                                 [1, 16, 14, 14]           1,448
│    └─Sequential: 2-3                                  [1, 24, 7, 7]             --
│    │    └─MBConv: 3-5                                 [1, 24, 7, 7]             6,004
│    │    └─MBConv: 3-6                                 [1, 24, 7, 7]             10,710
│    └─Sequential: 2-4                                  [1, 40, 4, 4]             --
│    │    └─MBConv: 3-7                                 [1, 40, 4, 4]             15,350
│    │    └─MBConv: 3-8                                 [1, 40, 4, 4]             31,290
│    └─Sequential: 2-5                                  [1, 80, 2, 2]             --
│    │    └─MBConv: 3-9                                 [1, 80, 2, 2]             37,130
│    │    └─MBConv: 3-10                                [1, 80, 2, 2]             102,900
│    │    └─MBConv: 3-11                                [1, 80, 2, 2]             102,900
│    └─Sequential: 2-6                                  [1, 112, 2, 2]            --
│    │    └─MBConv: 3-12                                [1, 112, 2, 2]            126,004
│    │    └─MBConv: 3-13                                [1, 112, 2, 2]            208,572
│    │    └─MBConv: 3-14                                [1, 112, 2, 2]            208,572
│    └─Sequential: 2-7                                  [1, 192, 1, 1]            --
│    │    └─MBConv: 3-15                                [1, 192, 1, 1]            262,492
│    │    └─MBConv: 3-16                                [1, 192, 1, 1]            587,952
│    │    └─MBConv: 3-17                                [1, 192, 1, 1]            587,952
│    │    └─MBConv: 3-18                                [1, 192, 1, 1]            587,952
│    └─Sequential: 2-8                                  [1, 320, 1, 1]            --
│    │    └─MBConv: 3-19                                [1, 320, 1, 1]            717,232
│    └─Conv2dNormActivation: 2-9                        [1, 1280, 1, 1]           --
│    │    └─Conv2d: 3-20                                [1, 1280, 1, 1]           409,600
│    │    └─BatchNorm2d: 3-21                           [1, 1280, 1, 1]           2,560
│    │    └─SiLU: 3-22                                  [1, 1280, 1, 1]           --
├─AdaptiveAvgPool2d: 1-2                                [1, 1280, 1, 1]           --
├─Sequential: 1-3                                       [1, 10]                   --
│    └─Dropout: 2-10                                    [1, 1280]                 --
│    └─Linear: 2-11                                     [1, 10]                   12,810
=========================================================================================================
Total params: 4,020,358
Trainable params: 4,020,358
Non-trainable params: 0
Total mult-adds (M): 8.11
=========================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 1.97
Params size (MB): 16.08
Estimated Total Size (MB): 18.06
=========================================================================================================

We can see that the model is composed of two main modules:

  • Sequential: 1-1 (the feature extractor, which is in turn composed of multiple submodules)
  • Sequential: 1-3 (the classifier)

To get the names of the top-level modules you can use:

set([i.split(".")[0] for i in model.state_dict().keys()])

which will output:

{'classifier', 'features'}
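
If you also want the full dotted path of every prunable leaf layer (the same paths you see in named_parameters, minus the trailing .weight / .bias), named_modules gives them directly; a minimal sketch:

# Leaf modules (Conv2d, BatchNorm2d, Linear, ...) are the ones whose
# "weight" you can hand to the pruning functions.
for name, module in model.named_modules():
    if name and len(list(module.children())) == 0:
        print(name, "->", type(module).__name__)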

Pruning:

For pruning we're mostly interested in the weights and biases of the individual layers, not the container submodules:

for name, param in model.named_parameters():
    print(name)

This will output:

features.0.0.weight
features.0.1.weight
features.0.1.bias
features.1.0.block.0.0.weight
features.1.0.block.0.1.weight
features.1.0.block.0.1.bias
...

features.8.1.weight
features.8.1.bias
classifier.1.weight
classifier.1.bias

Now you can prune any layer you want, as described in the documentation:

module_to_prune = model.features[0][0]
prune.random_unstructured(module_to_prune, name="weight", amount=0.3)

Let's check that the specified weights got pruned and that a new weight_orig parameter shows up in named_parameters:

for name, param in model.named_parameters():
    print(name)

Output:

features.0.0.weight_orig
features.0.1.weight
features.0.1.bias
features.1.0.block.0.0.weight
features.1.0.block.0.1.weight
...
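
Note that pruning only adds a weight_mask buffer and re-parametrizes the weight as weight_orig * weight_mask. Once you are happy with the result, you can make the pruning permanent with prune.remove; a minimal sketch, applied to the module pruned above:

# Fold weight_orig * weight_mask back into a plain "weight" parameter
# and drop the pruning re-parametrization.
prune.remove(module_to_prune, "weight")

# The zeros are now baked into the weight tensor itself:
w = module_to_prune.weight
print(f"sparsity: {100.0 * float((w == 0).sum()) / w.nelement():.2f}%")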

Answer from Sun Jiaojiao:

Thanks to Laassairi for the answer; I think the collection step can be done automatically.

A solution:

from torch.nn.utils import prune

parameters_to_prune = []

# Collect every module that owns a parameter named "weight".
# Note: this also picks up BatchNorm scale parameters, not only
# convolution and linear weights.
for name, param in model.named_parameters():
    if name.endswith('weight'):
        # get_submodule resolves a dotted path such as "features.1.0.block.0.0"
        module = model.get_submodule(name.rsplit('.', 1)[0])
        parameters_to_prune.append((module, 'weight'))

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.3,
)
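
To sanity-check the result, you can measure the global sparsity over the collected layers; with amount=0.3 it should come out at 30% of the collected weights. A minimal sketch reusing the parameters_to_prune list from above:

# Count exact zeros across all pruned weight tensors.
zeros = sum(float((module.weight == 0).sum()) for module, _ in parameters_to_prune)
total = sum(module.weight.nelement() for module, _ in parameters_to_prune)
print(f"global sparsity: {100.0 * zeros / total:.2f}%")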