About the batched backward of matmul in PyTorch: when to reduce?

How does backpropagation work in the batched matmul case? Assume we have a linear layer where the shape of x is [b, m, k] and the shape of W is [k, n]:

$$y = xW$$
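For concreteness, here is a minimal sketch of that forward pass (the sizes b, m, k, n are arbitrary values I picked for illustration):

```python
import torch

b, m, k, n = 4, 8, 16, 32                # arbitrary example sizes

x = torch.randn(b, m, k, requires_grad=True)
W = torch.randn(k, n, requires_grad=True)

# matmul broadcasts the 2-D W against the batched 3-D x
y = x @ W                                # shape [b, m, n]
print(y.shape)                           # torch.Size([4, 8, 32])
```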

then:

$$\frac{\partial L}{\partial W} = x^T \frac{\partial L}{\partial y}$$
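Since W is shared across the batch, I read this formula as a per-sample product accumulated over the batch dimension:

$$\frac{\partial L}{\partial W} = \sum_{i=1}^{b} x_i^T \, \frac{\partial L}{\partial y_i}$$

where $x_i$ is the $[m, k]$ slice of $x$ and $\partial L / \partial y_i$ is the matching $[m, n]$ slice of the output gradient.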

So what actually happens in the backward pass for W?

  1. We save the full $x^T$ tensor of shape [b, k, m], execute the batched matmul [b, k, m] x [b, m, n], and finally reduce the [b, k, n] result over the batch dimension to [k, n] to get the gradient (see the numerical sketch after this list).
  2. We directly save a reduced (averaged) $x^T$ tensor of shape [k, m] and execute a single matmul [k, m] x [m, n] (with the output gradient likewise reduced to [m, n]) to get the gradient.
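To make the question concrete, here is a small numerical sketch (my own code, not PyTorch's implementation) checking what autograd actually produces for W's gradient against option 1:

```python
import torch

b, m, k, n = 4, 8, 16, 32
x = torch.randn(b, m, k, requires_grad=True)
W = torch.randn(k, n, requires_grad=True)

y = x @ W                                    # [b, m, n]
grad_y = torch.randn(b, m, n)                # stand-in for dL/dy
y.backward(grad_y)                           # populates W.grad

# Option 1: batched matmul [b,k,m] x [b,m,n] -> [b,k,n], then sum-reduce over b.
grad_W_option1 = (x.detach().transpose(1, 2) @ grad_y).sum(dim=0)

# The same result via one flat matmul: [k, b*m] x [b*m, n] -> [k, n].
grad_W_flat = x.detach().reshape(b * m, k).T @ grad_y.reshape(b * m, n)

print(torch.allclose(W.grad, grad_W_option1, atol=1e-5))   # True
print(torch.allclose(W.grad, grad_W_flat, atol=1e-5))      # True
```

At least numerically, the reduction is a sum over the batch dimension; an averaged $x^T$ as in option 2 could not reproduce it in general, since the sum of per-sample products $\sum_i x_i^T \frac{\partial L}{\partial y_i}$ is not determined by the averaged factors alone.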

I have read PyTorch's source code, but I have trouble understanding what it does here.
