How does backprop work in the batched matmul case? Assume we have a linear layer where the shape of x is [b,m,k] and the shape of W is [k,n]:
$$y = xW$$
then:
$$\frac{\partial L}{\partial W} = x^T \frac{\partial L}{\partial y}$$
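As a quick shape check (the dimension values below are arbitrary placeholders of my own), each per-sample contribution $x_i^T \frac{\partial L}{\partial y_i}$ is a [k,m] x [m,n] product:

```python
import torch

b, m, k, n = 4, 3, 5, 2              # arbitrary example sizes
x = torch.randn(b, m, k)
W = torch.randn(k, n)

y = x @ W                            # W broadcasts over the batch dim: [b,m,n]
print(y.shape)                       # torch.Size([4, 3, 2])

grad_y = torch.randn_like(y)         # upstream gradient dL/dy, same shape as y
per_sample = x.transpose(1, 2) @ grad_y   # [b,k,m] @ [b,m,n] -> [b,k,n]
print(per_sample.shape)              # torch.Size([4, 5, 2]); reduce over b -> [k,n]
```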
So what happens in the backward pass for W? I can think of two possibilities (both are checked in the sketch after this list):
- We save the full $x^T$ tensor of shape [b,k,m], execute the batched matmul [b,k,m] x [b,m,n], and finally reduce the [b,k,n] result over the batch dimension to [k,n] to get the gradient.
- We directly save a reduced (averaged) $x^T$ tensor of shape [k,m] together with a reduced upstream gradient of shape [m,n], and execute a single matmul [k,m] x [m,n] to get the gradient.
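To make the two options concrete, here is a minimal sketch that checks each candidate against what autograd actually produces; the fold trick at the end is my own illustration of an equivalent single-matmul formulation, not a claim about which kernel PyTorch dispatches to:

```python
import torch

b, m, k, n = 4, 3, 5, 2
x = torch.randn(b, m, k, requires_grad=True)
W = torch.randn(k, n, requires_grad=True)

y = x @ W                        # [b,m,n]; W is broadcast across the batch dim
grad_y = torch.randn(b, m, n)    # an arbitrary upstream gradient dL/dy
y.backward(grad_y)               # populates W.grad

# Option 1: batched matmul [b,k,m] x [b,m,n] -> [b,k,n], then reduce to [k,n]
grad_W_1 = (x.detach().transpose(1, 2) @ grad_y).sum(dim=0)
print(torch.allclose(W.grad, grad_W_1))       # True

# Option 2: reduce first to [k,m] and [m,n], then one matmul [k,m] x [m,n]
grad_W_2 = x.detach().mean(dim=0).T @ grad_y.mean(dim=0)
print(torch.allclose(W.grad, grad_W_2))       # False in general

# Same result as option 1 without the [b,k,n] intermediate: fold the batch
# and row dims together, then one matmul [k, b*m] x [b*m, n] -> [k,n]
grad_W_folded = x.detach().reshape(b * m, k).T @ grad_y.reshape(b * m, n)
print(torch.allclose(W.grad, grad_W_folded))  # True
```

Whatever PyTorch does internally, its W.grad agrees with the product-then-reduce result, not the reduce-then-product one: a sum of per-sample products is not a product of sums.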
I have read PyTorch's source code, but I have trouble understanding it.