Pruning nn.Linear weights in place causes an unexpected error and requires slightly odd workarounds. Need an explanation


This fails

import torch

def test1():  
  layer = torch.nn.Linear(100, 10)
  x = 5 - torch.sum(layer(torch.ones(100)))
  x.backward()
  layer.weight.data = layer.weight.data[:, :90]
  layer.weight.grad.data = layer.weight.grad.data[:, :90]
  x = 5 - torch.sum(layer(torch.ones(90)))
  x.backward()
test1()

with this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-bb36a010bd86> in <cell line: 10>()
      8     x = 5 - torch.sum(layer(torch.ones(90)))
      9     x.backward()
---> 10 test1()
     11 # and this works as well
     12 

2 frames
/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    249     # some Python versions print out the first line of a multi-line function
    250     # calls in the traceback and some print out the last line
--> 251     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252         tensors,
    253         grad_tensors_,

RuntimeError: Function TBackward0 returned an invalid gradient at index 0 - got [10, 90] but expected shape compatible with [10, 100]

This works

import torch

def test2():  
  layer = torch.nn.Linear(100, 10)
  x = 5 - torch.sum(layer(torch.ones(100)))
  x.backward()
  del x    #main change
  layer.weight.data = layer.weight.data[:, :90]
  layer.weight.grad.data = layer.weight.grad.data[:, :90]
  x = 5 - torch.sum(layer(torch.ones(90)))
  x.backward()
test2()

and this works as well

import torch
def test3():  
  layer = torch.nn.Linear(100, 10)
  x = 5 - torch.sum(layer(torch.ones(100)))
  x.backward()
  layer.weight.data = layer.weight.data[:, :90]
  layer.weight.grad.data = layer.weight.grad.data[:, :90]
  layer.weight = torch.nn.Parameter(layer.weight)   #main change
  x = 5 - torch.sum(layer(torch.ones(90)))
  x.backward()
test3()

I encountered this while trying to implement a paper on model pruning (Temporal Neuron Variance Pruning). I believe it has something to do with the autograd graph, but I am not sure what exactly is going on. I've already seen the link on pruning and got my code working using the 3rd snippet; I am now trying to figure out why snippet 1 fails while snippets 2 and 3 work. Is there some explanation for why these almost identical code snippets work or fail?

Major points I'd like to figure out:

  1. What is TBackward0?
  2. Where is it defined?
  3. Where is the runtime error raised?
  4. Why is compatibility with the old shape expected, especially when the grad has been modified correctly? (I assume I have edited the tensors correctly because cases 2 and 3 work.)
  5. Can I change something else (other than the 2 working cases) to make this work?

There is 1 answer below.

Answered by Ashwath S

As you guessed, the issue is with the computational graph that autograd builds during the forward pass and walks during backpropagation.

Let me explain the above point:

When a tensor requires gradients, PyTorch tracks the operations you perform on it. During the forward pass, the backward functions for those operations are created and wired together into the graph that x.backward() later walks.
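
For example, you can walk that graph by hand from the output's grad_fn through next_functions. This is just a quick inspection sketch of my own, not code from the question:

import torch

layer = torch.nn.Linear(100, 10)
x = 5 - torch.sum(layer(torch.ones(100)))

def walk(fn, depth=0):
  # Recursively print the backward graph hanging off a grad_fn node
  if fn is None:
    return
  print("  " * depth + type(fn).__name__)
  for parent, _ in fn.next_functions:
    walk(parent, depth + 1)

walk(x.grad_fn)
# Somewhere in this printout you should see TBackward0: the auto-generated
# backward node for the weight transpose (weight.t()) that the linear layer
# applies internally. It feeds the AccumulateGrad node for layer.weight.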

In case 2, deleting x drops the last reference to the old graph, so autograd can free it together with the node that recorded the weight's old [10, 100] shape; the next forward pass then builds a fresh graph around the new [10, 90] shape. In case 3, wrapping the sliced weight in a new nn.Parameter creates a brand-new autograd leaf, which has the same effect. In case 1, the old x keeps the old graph alive, so the second forward pass still links to a node expecting a [10, 100] gradient, and the backward pass raises the shape mismatch you see.

The output tensor and the model parameters are connected to the graph.
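
If that is right, any way of releasing the old output before resizing the weight should behave like case 2. As a sketch (my own variant of the question's snippet, under that assumption), rebinding x instead of deleting it also avoids the error:

import torch

def test4():
  layer = torch.nn.Linear(100, 10)
  x = 5 - torch.sum(layer(torch.ones(100)))
  x.backward()
  x = None   # rebinding x drops the last reference to the old graph,
             # which should have the same effect as `del x` in case 2
  layer.weight.data = layer.weight.data[:, :90]
  layer.weight.grad.data = layer.weight.grad.data[:, :90]
  x = 5 - torch.sum(layer(torch.ones(90)))
  x.backward()
test4()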

If you want to see exactly where the TBackward0 node sits, use torchviz to visualize the computational graph.
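
A minimal sketch, assuming torchviz (and the Graphviz binaries) are installed:

# pip install torchviz
import torch
from torchviz import make_dot

layer = torch.nn.Linear(100, 10)
x = 5 - torch.sum(layer(torch.ones(100)))

# Renders the autograd graph to graph.png; TBackward0 should show up on the
# path between the layer.weight leaf and the forward matmul node.
make_dot(x, params=dict(layer.named_parameters())).render("graph", format="png")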