How to index row elements of a Matrix with a Matrix of indices for each row?


I have a matrix of indices I, e.g.

I = np.array([[1, 0, 2], [2, 1, 0]])

Each index in the i-th row of I selects an element from the i-th row of another matrix M.

So, given M, e.g.

M = np.array([[6, 7, 8], [9, 10, 11]])

M[I] should select:

[[7, 6, 8], [11, 10, 9]]
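To pin down the intended semantics, here is a plain-loop reference sketch: entry (i, j) of the result is M[i, I[i, j]]. The helper name `select_per_row` is hypothetical and only serves to state the rule; it is not meant as the efficient solution.

```python
import numpy as np

def select_per_row(M, I):
    """Hypothetical reference helper: result[i, j] = M[i, I[i, j]]."""
    result = np.empty(I.shape, dtype=M.dtype)
    for i in range(I.shape[0]):          # row i of I ...
        for j in range(I.shape[1]):      # ... picks columns from row i of M
            result[i, j] = M[i, I[i, j]]
    return result

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])
print(select_per_row(M, I))
```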

I could do:

I1 = np.repeat(np.arange(0, I.shape[0]), I.shape[1])
I2 = np.ravel(I)
Result = M[I1, I2].reshape(I.shape)

but this looks overly complicated; I am looking for a more elegant solution, preferably without flattening and reshaping.

In the example I used NumPy, but I am actually using JAX, so if there is a more efficient solution in JAX, feel free to share.

hpaulj (best answer)
I had to add a missing ']' to M:

In [108]: I = np.array([[1, 0, 2], [2, 1, 0]])
     ...: M = np.array([[6, 7, 8], [9, 10, 11]])
     ...: 
     ...: I,M
Out[108]: 
(array([[1, 0, 2],
        [2, 1, 0]]),
 array([[ 6,  7,  8],
        [ 9, 10, 11]]))

Advanced indexing with broadcasting:

In [110]: M[np.arange(2)[:,None],I]
Out[110]: 
array([[ 7,  6,  8],
       [11, 10,  9]])

The first index has shape (2,1), which broadcasts against the (2,3) shape of I, so together they select a (2,3) block of values: element (i, j) is M[i, I[i, j]].
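A small sketch of what that broadcasting does, using `np.broadcast_arrays` to make the expanded row index visible. It uses `np.arange(M.shape[0])` in place of the hard-coded `np.arange(2)` so it works for any number of rows; that generalization is my own, not from the answer.

```python
import numpy as np

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])

rows = np.arange(M.shape[0])[:, None]    # shape (2, 1): one row number per row
print(rows.shape, I.shape)               # (2, 1) (2, 3)

# Broadcasting stretches `rows` along axis 1 to match the shape of I.
rows_b, cols_b = np.broadcast_arrays(rows, I)
print(rows_b)
# [[0 0 0]
#  [1 1 1]]

# Indexing with the explicitly broadcast pair gives the same result.
result = M[rows_b, cols_b]
assert (result == M[rows, I]).all()
print(result)
```

The explicit `rows_b` is exactly the `I1` array from the question, reshaped to (2,3); broadcasting simply avoids building it.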

Shaun Han

How about this one-liner? The idea is to enumerate the rows of M together with their row numbers, so that for each row of M you can look up the corresponding row of the index matrix I.

import numpy as np

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])

Result = np.array([row[I[i]] for i, row in enumerate(M)])
print(Result)

Output:

[[ 7  6  8]
 [11 10  9]]

Mustafa Aydın

np.take_along_axis can also be used here to take values of M using the indices I along axis=1:

>>> np.take_along_axis(M, I, axis=1)

array([[ 7,  6,  8],
       [11, 10,  9]])
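Since the question is really about JAX: `jax.numpy` supports both NumPy approaches above (integer-array indexing with broadcasting, and `jnp.take_along_axis`). A JAX-specific alternative is `jax.vmap`, which maps a single-row gather `row[idx]` over the leading axis of M and I together. This is a sketch assuming `jax` is installed; I have not benchmarked which variant is fastest.

```python
import jax
import jax.numpy as jnp

I = jnp.array([[1, 0, 2], [2, 1, 0]])
M = jnp.array([[6, 7, 8], [9, 10, 11]])

# 1) Broadcast advanced indexing, same as in NumPy.
r1 = M[jnp.arange(M.shape[0])[:, None], I]

# 2) take_along_axis along the column axis.
r2 = jnp.take_along_axis(M, I, axis=1)

# 3) vmap a per-row gather over axis 0 of both M and I (default in_axes=0).
r3 = jax.vmap(lambda row, idx: row[idx])(M, I)

print(r1, r2, r3, sep="\n")
```

All three produce the same (2,3) array; the vmap version reads closest to the per-row description in the question.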