I have a matrix of indices I, e.g.
I = np.array([[1, 0, 2], [2, 1, 0]])
Each index in the i-th row of I selects an element from the i-th row of another matrix M.
So given M, e.g.
M = np.array([[6, 7, 8], [9, 10, 11]])
indexing M with I should select:
[[7, 6, 8], [11, 10, 9]]
I could have:
I1 = np.repeat(np.arange(0, I.shape[0]), I.shape[1])
I2 = np.ravel(I)
Result = M[I1, I2].reshape(I.shape)
but this looks very complicated, and I am looking for a more elegant solution, preferably without flattening and reshaping.
In the example I used NumPy, but I am actually using JAX. So if there is a more efficient solution in JAX, feel free to share.
I had to add a ']' to M.
Advanced indexing with broadcasting: the first index has shape (2,1), which pairs with the (2,3) shape of I to select a (2,3) block of values.
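A minimal sketch of that indexing, assuming the first index is a column of row numbers built with np.arange. The names rows, result and alt are illustrative and not taken from the original post. np.take_along_axis is shown as one possible alternative that avoids building the row index by hand:

```python
import numpy as np

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])

# Row numbers as a column vector of shape (2, 1).
# It broadcasts against I, which has shape (2, 3).
rows = np.arange(M.shape[0])[:, None]

# Pairs rows[i, 0] with I[i, j], so result[i, j] == M[i, I[i, j]].
result = M[rows, I]

# Alternative: pick along axis 1 using the indices in I.
alt = np.take_along_axis(M, I, axis=1)

print(result)
print(alt)
```

Both expressions should give the (2,3) array from the question without any flattening or reshaping. jax.numpy exposes the same indexing syntax and a take_along_axis function, so swapping np for jnp is likely to work unchanged, though that has not been checked here.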