Use an indexer tensor to route a batched input tensor into different Sequentials


I'm having trouble with the following scenario:

Suppose we have an indexer tensor (or vector, for that matter) that contains, for example, the values tensor([0, 1, 1, 0]). Suppose I also have a batch of the same length (4) of input tensors, each of length 64: tensor([[foo], [bar], [tee], [aaa]]).

Based on the numbers in the indexer tensor, I need to collect the corresponding input tensors. That is, for index 0 we have tensor([[foo], [aaa]]) and for index 1 we have tensor([[bar], [tee]]).
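For concreteness, the grouping step on its own looks like this with boolean masks (foo, bar, etc. stand in for arbitrary 64-dimensional rows):

import torch

indexer = torch.tensor([0, 1, 1, 0])  # which Sequential each row should go to
inputs = torch.randn(4, 64)           # stands in for [[foo], [bar], [tee], [aaa]]

group_0 = inputs[indexer == 0]        # rows 0 and 3 (foo, aaa), shape (2, 64)
group_1 = inputs[indexer == 1]        # rows 1 and 2 (bar, tee), shape (2, 64)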

This is because I now have to feed each "group" of inputs to a different Sequential net, whose structure is not relevant here. The point is that each of these two Sequentials gives outputs: output_0 for index 0's inputs and output_1 for index 1's inputs. The outputs look like tensor([[output_foo], [output_aaa]]) and tensor([[output_bar], [output_tee]]), respectively.

The last step is to use the indexer tensor from the beginning to join all the output tensors back into one, i.e. tensor([[output_foo], [output_bar], [output_tee], [output_aaa]]), while preserving the gradient.
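For concreteness, the join on its own can be expressed with a stable argsort of the indexer acting as an inverse permutation (a sketch with placeholder tensors standing in for the real group outputs):

import torch

indexer = torch.tensor([0, 1, 1, 0])
output_0 = torch.randn(2, requires_grad=True)  # stands in for [output_foo, output_aaa]
output_1 = torch.randn(2, requires_grad=True)  # stands in for [output_bar, output_tee]

grouped = torch.cat([output_0, output_1])    # group order: foo, aaa, bar, tee
order = torch.argsort(indexer, stable=True)  # [0, 3, 1, 2]: original row of each grouped entry
outputs = torch.empty_like(grouped)
outputs[order] = grouped                     # [output_foo, output_bar, output_tee, output_aaa]
print(outputs.grad_fn)                       # not None: the scatter is recorded by autograd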

I tried doing this with for-loops: looping first over index 0, then over index 1, and inside each iteration collecting the corresponding inputs and computing the corresponding outputs. I then had to merge everything back with yet another for-loop. The problem is that although the intermediate output tensors from each iteration have their own grad_fn, the final tensor doesn't. That means I'm breaking the gradient somewhere, and I suspect it is because of the loops, or because of how I write into the output tensor.
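In fact, the write seems to be enough on its own to reproduce the symptom (a minimal sketch; the copy goes through .data, as it does in my full example below):

import torch

x = torch.randn(3, requires_grad=True)
y = x * 2                        # y has a grad_fn (MulBackward0)
out = torch.zeros(3, requires_grad=True)
out.data[0].copy_(y.data[0])     # mutates raw storage: no graph edge is created
print(out.grad_fn)               # None: out is still a disconnected leaf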

My question is: is there a way to achieve this without using for-loops? I have the idea that using the indexer tensor and maybe some boolean masks would do the job in a sort of matrix-like manner, but I still can't visualize what it would look like in real code.

Here's a minimal example:

import torch

B = 5   # batch size
F = 64  # feature dimension

inputs = torch.randn(B, F)  # placeholder batch
masks = torch.rand(B, 2)    # placeholder scores; argmax picks one of 2 nets per row

outputs = torch.zeros(B, requires_grad=True)

indexer = torch.argmax(masks, dim=1)  # (B,) int
for idx in indexer.unique():
    mask = (indexer == idx)  # (B,) bool
    curr_inputs = torch.masked_select(inputs, mask.unsqueeze(-1).repeat(1, F)).view(-1, F)  # (b, F), only the rows for idx

    # Apply the appropriate Sequential to each input according to its original index
    curr_outputs = apply_to_sequential_by_idx(curr_inputs, idx)

    # Join the outputs back together to get back to shape (B,)
    j = 0
    for i, index in enumerate(indexer):
        if idx == index:
            # this copy goes through .data, so autograd does not record it
            outputs.data[i].copy_(curr_outputs.data[j])
            j += 1

return outputs  # Gradient is lost here (no grad_fn)
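To show the direction I'm imagining, here is a sketch of what the mask-based version might look like (assuming each Sequential maps (b, F) to (b, 1) and that the nets live in a list indexed by group id; both are simplifications):

def route_through_nets(inputs, indexer, nets):
    # inputs: (B, F), indexer: (B,) long, nets: list of nn.Sequential, one per index
    outputs = inputs.new_zeros(inputs.shape[0])  # plain zeros; a requires_grad=True leaf would raise on in-place writes
    for idx in indexer.unique():
        mask = indexer == idx                    # (B,) bool
        curr = nets[int(idx)](inputs[mask])      # (b, 1): this group's rows, in original order
        outputs[mask] = curr.squeeze(-1)         # in-place index assignment, recorded by autograd
    return outputs                               # back in batch order, with a grad_fn

This still has a loop, though over the distinct nets rather than the batch rows. Would something along these lines keep the gradient intact, and can the remaining loop be removed as well?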