Suppose we have a function f whose gradient is slow to compute, and two functions g1 and g2 whose gradients are easy to compute. In PyTorch, how can I calculate the gradients of z1 = g1(f(x)) and z2 = g2(f(x)) with respect to x without having to calculate the gradient of f twice?
Example:
import torch
import time
def slow_fun(x):
    A = x*torch.ones((1000,1000))
    B = torch.matrix_exp(1j*A)
    return torch.real(torch.trace(B))
x = torch.tensor(1.0, requires_grad = True)
y = slow_fun(x)
z1 = y**2
z2 = torch.sqrt(y)
start = time.time()
z1.backward(retain_graph = True)
end = time.time()
print("dz1/dx: ", x.grad)
print("duration: ", end-start, "\n")
x.grad = None
start = time.time()
z2.backward(retain_graph = True)
end = time.time()
print("dz2/dx: ", x.grad)
print("duration: ", end-start, "\n")
This prints
dz1/dx: tensor(-1673697.1250)
duration: 1.5571658611297607
dz2/dx: tensor(-13.2334)
duration: 1.3989012241363525
so calculating dz2/dx takes about as long as calculating dz1/dx.
The calculation of dz2/dx could be sped up if PyTorch stored dy/dx during the calculation of dz1/dx and then reused that result during the calculation of dz2/dx.
Is there a mechanism built into PyTorch to achieve this behavior?
You can decouple the nested functions using the chain rule: differentiate the slow inner function only once, then combine its gradient with the gradients of the cheap outer functions. However, the results may differ slightly from a plain backward pass due to numerical (floating-point) issues.
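Here is a minimal sketch of that decoupling, reusing the question's slow_fun and assuming scalar x and y as in the example (the names dydx, y_det, dz1dy, dz2dy are only illustrative): compute dy/dx once with torch.autograd.grad, rebuild the cheap outer functions on a detached copy of y, and multiply the pieces together.

import torch

def slow_fun(x):
    A = x*torch.ones((1000,1000))
    B = torch.matrix_exp(1j*A)
    return torch.real(torch.trace(B))

x = torch.tensor(1.0, requires_grad=True)

# Differentiate the expensive inner function only once: dy/dx.
y = slow_fun(x)
dydx, = torch.autograd.grad(y, x)

# Rebuild the cheap outer functions on a detached copy of y,
# so their graphs do not reach back into slow_fun.
y_det = y.detach().requires_grad_(True)
z1 = y_det**2
z2 = torch.sqrt(y_det)

# Cheap gradients with respect to y.
dz1dy, = torch.autograd.grad(z1, y_det)
dz2dy, = torch.autograd.grad(z2, y_det)

# Chain rule: dz/dx = dz/dy * dy/dx.
print("dz1/dx: ", dz1dy * dydx)
print("dz2/dx: ", dz2dy * dydx)

Because the chain-rule products are accumulated in a slightly different order than in a single backward pass, the printed values can differ from z1.backward() / z2.backward() in the last few digits; that is the numerical difference mentioned above.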