torch.argmax() returns different results for a tensor on Mac MPS and CPU


torch.argmax() returns the correct result for a tensor on the CPU, but an incorrect one on MPS.

I have a matrix called target_mixtures, and print(target_mixtures[0]) gives:

tensor([0.0010, 0.6827, 0.0010, 0.0261, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.2744, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010], device='mps:0')

When I run print(torch.argmax(target_mixtures[0].cpu())), it gives 1 as expected (0.6827 is the largest value).

However, when I run print(torch.argmax(target_mixtures[0])) without moving the tensor to the CPU (i.e. on Mac MPS), it gives 0. In fact, on MPS it returns 0 for every target_mixtures[i].

I have also attached a screenshot of my Jupyter notebook; these are the 3 consecutive cells I executed.

[Screenshot: the actual code]
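A minimal sketch for checking the CPU/MPS discrepancy outside the notebook, assuming a PyTorch build that includes the MPS backend. It rebuilds the row from the printed values (which are rounded to 4 decimals) and compares argmax computed on a CPU copy with argmax computed on the tensor's own device. The helper name compare_argmax is hypothetical, not part of PyTorch. A freshly built tensor may not reproduce the problem if it depends on how target_mixtures was produced (its dtype, strides, or the operations that created it); in that case, pass target_mixtures[0] itself to the helper.

```python
import torch


def compare_argmax(t: torch.Tensor):
    """Return (cpu_index, device_index) for the argmax of a 1-D tensor.

    cpu_index is computed on a CPU copy of t; device_index is computed on
    t's own device (e.g. "mps"), or is None if t already lives on the CPU.
    """
    cpu_index = torch.argmax(t.cpu()).item()
    device_index = None if t.device.type == "cpu" else torch.argmax(t).item()
    return cpu_index, device_index


# Values copied from the printed target_mixtures[0] (rounded to 4 decimals).
values = [0.0010, 0.6827, 0.0010, 0.0261, 0.0010, 0.0010, 0.0010, 0.0010,
          0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.2744, 0.0010,
          0.0010, 0.0010, 0.0010, 0.0010]

# Use MPS when this PyTorch build and machine support it, otherwise stay on CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
row = torch.tensor(values, dtype=torch.float32, device=device)

cpu_index, device_index = compare_argmax(row)
print("device:", device)
print("argmax on CPU copy:", cpu_index)
print("argmax on device:", device_index)
```

If the two indices differ for a tensor built this way, the snippet is a self-contained reproduction; if they agree, the difference is more likely tied to how the original target_mixtures tensor was created.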

