Backpropagation through concatenation of elements of a batch


I have two feature vectors V1(N, F1, 1) and V2(N, F2, 1). I want to concatenate them along dimension 1 to create a vector V3(N, F1+F2, 1) and apply self-attention across the elements of the batch, i.e. across N. To that end, I was thinking of splitting the batch and concatenating its elements into a sequence with batch size 1, embedding dimension F1+F2, and sequence length N. The new feature vector would then look like V3'(1, N, F1+F2) (this assumes the previous vectors have been flattened along their last dimension).
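The reshaping described above can be sketched in PyTorch as follows (the sizes N, F1, F2 are arbitrary placeholders, not values from the question):

```python
import torch

# Placeholder sizes: batch of 4, feature dims 8 and 16
N, F1, F2 = 4, 8, 16
v1 = torch.randn(N, F1, 1)
v2 = torch.randn(N, F2, 1)

# Concatenate along the feature dimension: (N, F1+F2, 1)
v3 = torch.cat([v1, v2], dim=1)

# Treat the batch as a sequence: drop the trailing singleton dim,
# then add a leading batch dim of 1 -> (1, N, F1+F2)
v3_seq = v3.squeeze(-1).unsqueeze(0)

print(v3.shape)      # torch.Size([4, 24, 1])
print(v3_seq.shape)  # torch.Size([1, 4, 24])
```

The result has the shape expected by attention modules that take input as (batch, sequence, embedding).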

My question is: is this possible? Would there be issues with backpropagating gradients because the elements of the batch are split apart?

[Visual representation: diagram omitted]

My feeling is that this should work and that gradients will propagate without issues, but I'm a beginner and out of my depth here.


Answer by lejlot:

It depends on the details of your implementation, but in general, concatenation ops in modern frameworks propagate gradients exactly the way you would expect.

In particular, torch.cat is fully differentiable.
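A minimal sketch confirming this, using a self-attention layer as the stand-in downstream op (sizes and module choice are illustrative, not from the question):

```python
import torch

N, F1, F2 = 4, 8, 16
v1 = torch.randn(N, F1, 1, requires_grad=True)
v2 = torch.randn(N, F2, 1, requires_grad=True)

# Concatenate, then reshape the batch into a length-N sequence: (1, N, F1+F2)
seq = torch.cat([v1, v2], dim=1).squeeze(-1).unsqueeze(0)

# Self-attention across the N elements of the former batch
attn = torch.nn.MultiheadAttention(embed_dim=F1 + F2, num_heads=1, batch_first=True)
out, _ = attn(seq, seq, seq)

# Backpropagate through attention, reshape, and concatenation
out.sum().backward()

# Both original inputs receive gradients with their original shapes
print(v1.grad.shape)  # torch.Size([4, 8, 1])
print(v2.grad.shape)  # torch.Size([4, 16, 1])
```

The gradient of the concatenated tensor is simply split back into the pieces it was built from, so splitting the batch apart poses no problem for autograd.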