Multi-label classification Transformer: decoder branches produce identical outputs on validation and test


I have a Transformer model that needs to predict three different labels, so my forward pass runs three decoder branches:

    def forward(self, src, target):
        src_mask = self.make_src_mask(src)
        target_mask = self.make_target_mask(target)
        enc_src = self.encoder(src, src_mask)
        # Three branches, each a full decoder pass over the same inputs
        out_1 = self.decoder(target, enc_src, src_mask, target_mask)  # Plain
        out_2 = self.decoder(target, enc_src, src_mask, target_mask)  # Mid
        out_3 = self.decoder(target, enc_src, src_mask, target_mask)  # Dir
        out_4 = self.relu(out_3)  # Dir branch continues from out_3
        flatten = torch.flatten(out_2, start_dim=1, end_dim=2)
        out_5 = self.fc_out_mid(flatten)  # size [num_samples, num_frames]
        out_4 = self.fc_out_direction(out_4)
        out_4 = torch.nn.functional.log_softmax(out_4, dim=2)
        return out_1, out_5, out_4, out_2, out_3

Loss is computed on out_1, out_5, and out_4 separately, and the three losses are then combined. In essence, I expect out_1, out_2, and out_3 to have the same shape but different values.
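For context, my training step combines the losses roughly like this (the criteria and the plain_labels, mid_labels, dir_labels tensors below are placeholders, not my exact code):

    import torch

    # Placeholder criteria; my real loop uses task-specific losses.
    criterion_plain = torch.nn.CrossEntropyLoss()
    criterion_mid = torch.nn.MSELoss()
    criterion_dir = torch.nn.NLLLoss()  # out_4 is already log_softmax-ed

    out_1, out_5, out_4, out_2, out_3 = model(src, target)
    loss_plain = criterion_plain(out_1.transpose(1, 2), plain_labels)  # [N, C, L] vs [N, L]
    loss_mid = criterion_mid(out_5, mid_labels)
    loss_dir = criterion_dir(out_4.transpose(1, 2), dir_labels)
    loss = loss_plain + loss_mid + loss_dir  # one combined backward pass
    loss.backward()

When I print the three decoder outputs after training, i.e. if I do this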

    print(out_1) 
    print(out_2) 
    print(out_3)

I get different values on the training set (which matches my expectation, since each branch is trained with a different loss). But when I do the same on the validation and test sets, I get exactly the same values for all three. How is this possible? What am I doing incorrectly here?
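For reference, this is roughly how I check the outputs on a validation batch (model, src, and target here stand in for my actual module and tensors):

    import torch

    # Compare the three branch outputs on a validation batch;
    # `model`, `src`, `target` are placeholders for my actual objects.
    model.eval()
    with torch.no_grad():
        out_1, out_5, out_4, out_2, out_3 = model(src, target)
    print(torch.equal(out_1, out_2))  # prints True on validation/test
    print(torch.equal(out_2, out_3))  # prints True on validation/test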
