I have a NumPy array A of shape (n, m, s). For each index along the 0th axis, I need the (row, column) position of the maximum within the corresponding (m, s)-shaped array.
For example:
import numpy as np
np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])
A[0] is:
array([[5, 8, 9],
       [5, 0, 0],
       [1, 7, 6]])
I want to obtain (0, 2), i.e. the position of 9 here.
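Ideally I would call something like aa = A.argmax() and get aa.shape == (10, 2) with aa[0] == [0, 2]. However, A.argmax() without an axis returns a single index into the flattened array.
How can I achieve this?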

Use np.unravel_index with a list comprehension, where block is the 3x3 (m x s) array at each step. This gives a list with 10 (n) index tuples, which you can convert to a NumPy array of the desired shape (n, 2).