Torch argmax() returns correct results for tensor on CPU, but incorrect on MPS.
I have a matrix called target_mixtures, and print(target_mixtures[0]) gives
tensor([0.0010, 0.6827, 0.0010, 0.0261, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.2744, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010], device='mps:0')
When I run print(torch.argmax(target_mixtures[0].cpu())), it gives 1 as expected (the 0.6827 is largest).
However, when I run print(torch.argmax(target_mixtures[0])) with the tensor left on the MPS device (Mac) instead of moving it to the CPU, it gives 0. On MPS it returns 0 for every row target_mixtures[i].
I have also attached a screenshot of my Jupyter notebook, which shows these three cells executed consecutively.
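A minimal sketch of a standalone reproduction, in case it helps anyone compare the two devices. It rebuilds the first row from the rounded values printed above, so it is an assumption that it behaves like the original target_mixtures, which may differ in dtype, exact values, or memory layout. The helper name compare_argmax is made up for illustration, and the script falls back to the CPU when MPS is not available.

```python
import torch


def compare_argmax(values, device):
    """Return (cpu_index, device_index) of argmax over the same data.

    `compare_argmax` is a hypothetical helper, not part of PyTorch.
    """
    cpu_tensor = torch.tensor(values, dtype=torch.float32)
    device_tensor = cpu_tensor.to(device)
    cpu_index = torch.argmax(cpu_tensor).item()
    device_index = torch.argmax(device_tensor).item()
    return cpu_index, device_index


# First row, rebuilt from the rounded values printed in the question
# (assumption: the real tensor holds more precise values).
row = [0.0010, 0.6827, 0.0010, 0.0261] + [0.0010] * 10 + [0.2744] + [0.0010] * 5

# Use MPS when this PyTorch build exposes it and it is available,
# otherwise fall back to the CPU so the script still runs.
mps_backend = getattr(torch.backends, "mps", None)
device = "mps" if mps_backend is not None and mps_backend.is_available() else "cpu"

cpu_index, device_index = compare_argmax(row, device)
print(f"torch {torch.__version__}, device={device}")
print(f"argmax on cpu: {cpu_index}, argmax on {device}: {device_index}")
```

The CPU index should be 1, matching the result reported above. If the two indices differ on an MPS machine, the printed PyTorch version is worth including in a bug report. If they agree, the problem may depend on how target_mixtures was created, not on argmax alone.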