Given a linear system Ax = b with x = [x1; x2], where A and b are given and x1 is also known, and A is a square matrix, I want to compute the gradients of x1 with respect to A and b, i.e. dx1/dA and dx1/db. It seems that torch.autograd.backward() can only compute these gradients by differentiating through a solve of the whole linear system. I would like to make this more efficient by not solving for x1 again, since it is already known. Is that feasible? If so, how can I implement it in PyTorch?
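For context, my understanding of the backward rule (the standard adjoint identities for a linear solve, which I assume is what PyTorch implements): for a scalar loss L with g = dL/dx, one solves A^T y = g once, and then dL/db = y and dL/dA = -y x^T. If that is right, the adjoint solve with A^T in backward is unavoidable, and the cost I am hoping to skip is only the forward solve for the already-known x1.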
For now, I only know how to get the gradient by solving the whole system:
import torch

# Block structure: [A1 A2; A3 A4] @ [x1; x2] = [b1; b2]
# Ai has shape (mi, ni); n2 is chosen so that A is square
m1 = 10
n1 = 10
m3 = 5
n3 = n1
m2 = m1
n2 = m1 + m3 - n1
m4 = m3
n4 = n2
A = torch.rand(m1 + m3, n1 + n2, requires_grad=True)
A1 = A[:m1,:n1]
A2 = A[:m1,n1:]
A3 = A[m1:,:n1]
A4 = A[m1:,n1:]
# x1 is already given
x1 = torch.ones(n1, 1, requires_grad=True)
# x2 needs to be solved; x2_gt is the ground truth used to construct b
x2_gt = torch.rand((n2, 1))
# detach so that b1 and b2 are leaves that accumulate their own gradients
b1 = (A1 @ x1 + A2 @ x2_gt).detach().requires_grad_(True)
b2 = (A3 @ x1 + A4 @ x2_gt).detach().requires_grad_(True)
b = torch.vstack((b1,b2))
x = torch.linalg.solve(A, b)
# seed only the x1 block so the accumulated gradients correspond to x1 alone
x.backward(torch.vstack((torch.ones(n1, 1), torch.zeros(n2, 1))))
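One direction I am considering, as a minimal sketch under the adjoint identities above (the class name SolveReusingX is my own invention, not a PyTorch API, and I have not verified it with torch.autograd.gradcheck): a custom torch.autograd.Function whose forward just returns the already-known solution without calling a solver, and whose backward performs the single adjoint solve with A^T. Note that backward still needs the full x to form dL/dA = -y x^T, so x2 has to be available too, e.g. from the smaller solve A4 @ x2 = b2 - A3 @ x1; in the toy setup above, x2_gt plays that role.

class SolveReusingX(torch.autograd.Function):
    # Forward returns a pre-computed solution of A @ x = b (no solve);
    # backward wires up the standard adjoint rule with one solve against A^T.

    @staticmethod
    def forward(ctx, A, b, x_known):
        # Trust the caller that A @ x_known == b and skip the forward solve.
        ctx.save_for_backward(A, x_known)
        return x_known.clone()

    @staticmethod
    def backward(ctx, grad_x):
        A, x = ctx.saved_tensors
        # The one unavoidable solve: y = A^{-T} grad_x
        y = torch.linalg.solve(A.transpose(-2, -1), grad_x)
        grad_A = -y @ x.transpose(-2, -1)  # dL/dA = -y x^T
        grad_b = y                         # dL/db = y
        return grad_A, grad_b, None        # no gradient for x_known itself

# Usage with the setup above: no 15x15 forward solve is performed.
# (If run after the full-solve example, zero A.grad, b1.grad, b2.grad first,
# since gradients accumulate.)
x_known = torch.vstack((x1, x2_gt)).detach()
x = SolveReusingX.apply(A, b, x_known)
x.backward(torch.vstack((torch.ones(n1, 1), torch.zeros(n2, 1))))
# A.grad, b1.grad and b2.grad now hold the gradients of sum(x1)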