I am training an EfficientNet model in PyTorch and, to reduce overfitting, I want to prune some of its parameters.
My model is implemented as follows:
import torchvision.models as models
import torch.nn as nn
model = models.efficientnet_b0(pretrained=True)
for params in model.parameters():
    params.requires_grad = True
model.classifier[1] = nn.Linear(in_features=1280, out_features=num_classes)
This tutorial for pruning a pytorch model suggests isolating a module and then pruning it as follows:
module = model.conv1
prune.random_unstructured(module, name="weight", amount=0.3)
However, there is no module named conv1 in the efficientnet model. I then listed the modules of the efficientnet model and got the following:
(features): Sequential(
  (0): Conv2dNormActivation(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): SiLU(inplace=True)
  )
  (1): Sequential(
    (0): MBConv(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): SiLU(inplace=True)
...
However, I do not know which of these are the module names, nor do I know how to isolate the module. How can I do so?
Am I on the right path for pruning my original model? Is there a different approach I should take? Thank you for the clarification and help.
To start with, let's load the model:
A clear blueprint of the model:
Printing the model directly isn't a very clear way to get an overview of its different components. I personally prefer using torchinfo, e.g.:
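This gives you a table of the model's layers with their output shapes and parameter counts.
We can see that the model is composed of two main modules, features and classifier, with an avgpool layer between them.
To get the names of the primary modules you can use: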
Which will output features, avgpool and classifier.
Pruning:
For pruning we're mostly interested in the weights and biases of the individual layers, not the container submodules:
This will output the dotted name of every parameter, such as features.0.0.weight for the first convolution.
Now you can prune each layer you want, as specified in the documentation:
Let's check whether the specified weights got pruned, and whether there is a new weight_orig entry in named_parameters:
The output should list weight_orig among the module's parameters and weight_mask among its buffers.