I have a transformer model that needs to output three different labels, so my forward pass runs the decoder three times and feeds each output into its own head:
def forward(self, src, target):
    src_mask = self.make_src_mask(src)
    target_mask = self.make_target_mask(target)
    enc_src = self.encoder(src, src_mask)

    # Three separate decoder passes over the same encoder output, one per label
    out_1 = self.decoder(target, enc_src, src_mask, target_mask)  # Plain
    out_2 = self.decoder(target, enc_src, src_mask, target_mask)  # Mid
    out_3 = self.decoder(target, enc_src, src_mask, target_mask)  # Dir

    # "Dir" branch
    out_4 = self.relu(out_3)

    # "Mid" branch
    flatten = torch.flatten(out_2, start_dim=1, end_dim=2)
    out_5 = self.fc_out_mid(flatten)  # size [num_samples, num_frames]

    # "Dir" branch continued: projection, then log-probabilities over dim=2
    out_4 = self.fc_out_direction(out_4)
    out_4 = torch.nn.functional.log_softmax(out_4, dim=2)

    return out_1, out_5, out_4, out_2, out_3
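A rough sketch of one training step, to show how the three outputs are supervised (the specific loss functions, unit weights, and target tensor names here are placeholders, not my exact code):

import torch
import torch.nn.functional as F

def training_step(model, optimizer, src, target, plain_target, mid_target, dir_target):
    # Placeholder training step: each head gets its own loss, then the losses are summed.
    model.train()
    optimizer.zero_grad()

    out_1, out_5, out_4, out_2, out_3 = model(src, target)

    loss_plain = F.mse_loss(out_1, plain_target)                # "Plain" head (placeholder loss)
    loss_mid = F.mse_loss(out_5, mid_target)                    # "Mid" head, [num_samples, num_frames]
    loss_dir = F.nll_loss(out_4.permute(0, 2, 1), dir_target)   # "Dir" head: log-probs over dim=2

    loss = loss_plain + loss_mid + loss_dir                     # combined loss (unit weights as a placeholder)
    loss.backward()
    optimizer.step()
    return loss.item()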
So: loss is computed on out_1, out_5, and out_4 separately and then combined, roughly as sketched above. In essence I expect out_1, out_2, and out_3 to have the same shape but different values. When I print them after training, e.g.
print(out_1)
print(out_2)
print(out_3)
I get different values for the three tensors on the training set (which matches my expectation, since each of these branches is trained against its own loss), but when I do the same on the validation and test sets I get exactly the same values for all three. How is this possible? What am I doing incorrectly here?
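For concreteness, the comparison I'm doing looks roughly like this (the data loader variables and the use of torch.allclose are just for illustration):

import torch

def compare_outputs(model, loader):
    # Run one batch through the model and check whether the three decoder outputs match.
    src, target = next(iter(loader))
    with torch.no_grad():
        out_1, out_5, out_4, out_2, out_3 = model(src, target)
    print(torch.allclose(out_1, out_2), torch.allclose(out_1, out_3))

compare_outputs(model, train_loader)  # on training batches I see the three outputs differ
compare_outputs(model, val_loader)    # on validation/test batches they are exactly identical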