I am writing an ElasticWeightConsolidation method, and for that I need to compute the Fisher matrix. As I understand it, the Fisher matrix is just the Hessian of the likelihood with respect to the weights of the neural network. There is a convenient function for this, torch.autograd.functional.hessian(func, inputs, create_graph=False, strict=False).
So I want to compute hessian(loss, weights), where loss = torch.nn.CrossEntropyLoss(). I also flattened the weights of the network into one long 1D tensor, so that I can simply take the diagonal elements of the Hessian, like this:
def flat_param(model_param=yann_lecun.parameters()):
    ans_data = []
    ans_data = torch.tensor(ans_data, requires_grad=True)
    ans_data = ans_data.to(device)
    for p in model_param:
        temp_data = p.data.flatten()
        ans_data = torch.cat((ans_data, temp_data))
    return ans_data
ans = flat_param(yann_lecun.parameters())
Then I tried hessian(loss, inputs=ans), but the problem is that loss also takes targets, and I don't want to compute the Hessian with respect to them. The task is MNIST classification, so the targets are the integers 0, ..., 9.
If I add the targets to the inputs, like hessian(loss, inputs=(ans, y_train_01)),
it crashes with the message "can't take gradient from integer". I also tried setting y_train_01.requires_grad = False, but that didn't help. I understand that loss also depends on y_train_01, but is there any way to declare that the targets are constants in my case?
You can create a new 'wrapper' function where the targets are fixed:
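For example (a minimal sketch: wrapped_loss is a placeholder name, loss is your torch.nn.CrossEntropyLoss() instance, and fixed_targets is your tensor of integer labels, e.g. y_train_01):

def wrapped_loss(features):
    # fixed_targets is captured from the enclosing scope,
    # so hessian() never tries to differentiate with respect to it
    return loss(features, fixed_targets)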
And then call:
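hessian(wrapped_loss, inputs=features_vals)

Here features_vals stands for the tensor you actually want to differentiate with respect to; only it is passed to hessian(), while the integer targets stay inside the wrapper.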
The second-order partial derivatives from this will be equivalent to the corresponding entries of the full Hessian, evaluated at the point (features_vals, fixed_targets).
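In your example, y_train_01 plays the role of fixed_targets, and whatever you pass as the first argument of loss plays the role of features_vals. Assuming features_vals is a single flat 1D tensor, the returned Hessian is a 2D (n, n) matrix, so the diagonal you mentioned can be taken directly (a sketch, with h naming the result of the call above):

fisher_diag = torch.diagonal(h)  # diagonal of the Hessian as a 1D tensor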