PyTorch: Calculate Hessian only for a subset of parameters?


I am implementing the Elastic Weight Consolidation method, and for that I need to compute the Fisher matrix. As I understand it, the Fisher matrix is just the Hessian of the likelihood with respect to the weights of the neural network. PyTorch has a convenient function for this: torch.autograd.functional.hessian(func, inputs, create_graph=False, strict=False).
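
For reference, here is a minimal sketch of how that function is called on a scalar-valued function of a single tensor input (the function f is just an illustrative example):

import torch
from torch.autograd.functional import hessian

def f(x):
    # simple scalar function: f(x) = sum(x_i^2), so the Hessian is 2*I
    return (x ** 2).sum()

x = torch.randn(3)
H = hessian(f, x)  # 3x3 tensor, equal to 2 * torch.eye(3)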

So I want to compute hessian(loss, weights), where loss = torch.nn.CrossEntropyLoss(). I also flattened the weights of the network into one long 1D tensor, so that I can simply take the diagonal elements of the Hessian like this:

def flat_param(model_param):
  # Concatenate every parameter tensor into one long 1D tensor.
  # (A default argument like model_param=yann_lecun.parameters() is a bug:
  # it is evaluated once at definition time, and the generator would be
  # exhausted after the first call, so the generator is passed in instead.)
  ans_data = torch.tensor([], requires_grad=True).to(device)
  for p in model_param:
    temp_data = p.data.flatten()
    ans_data = torch.cat((ans_data, temp_data))
  return ans_data

ans = flat_param(yann_lecun.parameters())
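
As an aside, torch.nn.utils.parameters_to_vector does essentially the same flattening in a single call (and, unlike going through p.data, keeps the result connected to the autograd graph):

from torch.nn.utils import parameters_to_vector
ans = parameters_to_vector(yann_lecun.parameters())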

Then I tried hessian(loss, inputs=ans), but the problem is that loss also takes the targets, and I don't want to compute the Hessian with respect to them. The task is MNIST classification, so the targets are the integers 0, ..., 9, and if I add y_train to the inputs like this: hessian(loss, inputs=(ans, y_train_01))

it crashes with an error along the lines of "can't take gradient from integer". I also tried setting y_train_01.requires_grad = False, but that didn't help. I understand that loss also depends on y_train_01, but is there any way to tell autograd that the targets are constants in my case?

1 Answer

Answered by iacob:

You can create a new 'wrapper' function where the targets are fixed:

def original_model(features, targets):
    ...
    return ...

def feature_differentiable_model(features):
    fixed_targets = ...
    return original_model(features, fixed_targets)

And then call:

hessian(feature_differentiable_model, features_vals)

The second-order partial derivatives computed this way are equivalent to the corresponding entries of the full Hessian evaluated at the point (features_vals, fixed_targets).
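
Applied to the question's setting, here is a minimal sketch of this wrapper pattern: the targets are captured as a fixed constant in a closure, and the Hessian is taken only with respect to a flat weight vector. The small model and batch are hypothetical stand-ins for yann_lecun and the MNIST data, and torch.func.functional_call assumes a recent PyTorch (2.0+):

import torch
from torch.autograd.functional import hessian
from torch.func import functional_call  # PyTorch >= 2.0

# Hypothetical stand-ins for the question's model and data
model = torch.nn.Sequential(torch.nn.Linear(20, 10))
criterion = torch.nn.CrossEntropyLoss()
x_batch = torch.randn(32, 20)
y_batch = torch.randint(0, 10, (32,))  # fixed integer targets

# Record names/shapes so a flat vector can be split back into parameters
names = [n for n, _ in model.named_parameters()]
shapes = [p.shape for p in model.parameters()]
sizes = [p.numel() for p in model.parameters()]

def loss_wrt_weights(flat_w):
    # Rebuild a parameter dict from the flat vector (differentiable ops)
    chunks = flat_w.split(sizes)
    params = {n: c.view(s) for n, c, s in zip(names, chunks, shapes)}
    out = functional_call(model, params, (x_batch,))
    # y_batch comes from the closure, so autograd treats it as a constant
    return criterion(out, y_batch)

flat_w = torch.cat([p.detach().flatten() for p in model.parameters()])
H = hessian(loss_wrt_weights, flat_w)  # shape (P, P)
fisher_diag = H.diagonal()             # the diagonal, as in the question

Because the wrapper only takes flat_w as input, hessian never tries to differentiate with respect to the integer targets, which is exactly what avoids the "can't take gradient from integer" error.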