bits bits - 29 days ago 13
Python Question

In a 3D NumPy array, how can I extract the indices of the max element along the third dimension?

Example input 3D array of shape (2,2,2):

[[[1, 2],
  [4, 3]],
 [[5, 6],
  [8, 7]]]


My 3D array has shape (N, N, N); in the example above, N = 2.

I need all index triples [i, j, k] where k is the position of the maximum element along the third dimension for each (i, j). Expected output for the 3D array above:

[[0, 0, 1], # for element 2
[0, 1, 0], # for element 4
[1, 0, 1], # for element 6
[1, 1, 0]] # for element 8


It would be great if I could do this with the argmax or argwhere function. I want to avoid iteration and see whether it's possible to do this using NumPy functions alone.

Answer

Here's an approach using np.mgrid to get all the indices along the first and second axes, then stacking them along with the max indices from the third axis using np.column_stack -

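import numpy as np

# a is the 3D input array
d = a.argmax(-1)          # position of the max along the last axis, shape (m, n)
m, n = a.shape[:2]
c, r = np.mgrid[:m, :n]   # first- and second-axis index grids, each shape (m, n)
out = np.column_stack((c.ravel(), r.ravel(), d.ravel()))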
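For reference, here is a self-contained version of the snippet above, run on the 2x2x2 array from the question. The function name max_index_triples is just a placeholder chosen for this sketch; the final assert indexes back into a with the three output columns to confirm that every row points at a per-(i, j) maximum.

```python
import numpy as np

def max_index_triples(a):
    """Return an (m*n, 3) array of [i, j, k], where k = argmax of a[i, j, :]."""
    d = a.argmax(-1)                 # shape (m, n): position of the max along the last axis
    m, n = a.shape[:2]
    c, r = np.mgrid[:m, :n]          # first- and second-axis index grids, each shape (m, n)
    return np.column_stack((c.ravel(), r.ravel(), d.ravel()))

a = np.array([[[1, 2],
               [4, 3]],
              [[5, 6],
               [8, 7]]])

out = max_index_triples(a)
print(out)  # rows: [0 0 1], [0 1 0], [1 0 1], [1 1 0]

# Indexing back with the three columns should reproduce the per-(i, j) maxima.
assert (a[out[:, 0], out[:, 1], out[:, 2]] == a.max(-1).ravel()).all()
```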

Sample run -

In [96]: a
Out[96]: 
array([[[38, 49, 15, 61, 29],
        [31, 88, 45, 88, 20],
        [17, 97, 58, 61, 14],
        [43, 77, 56, 92, 89]],

       [[48, 91, 49, 35, 58],
        [53, 34, 58, 92, 52],
        [20, 35, 70, 41, 81],
        [60, 42, 85, 82, 41]],

       [[45, 41, 32, 41, 25],
        [59, 32, 90, 18, 47],
        [24, 93, 29, 89, 12],
        [80, 27, 12, 51, 33]]])

In [97]: out
Out[97]: 
array([[0, 0, 3],
       [0, 1, 1],
       [0, 2, 1],
       [0, 3, 3],
       [1, 0, 1],
       [1, 1, 3],
       [1, 2, 4],
       [1, 3, 2],
       [2, 0, 0],
       [2, 1, 2],
       [2, 2, 1],
       [2, 3, 0]])
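Since the question mentions argwhere, one possible argwhere-based variant is to compare a against its per-(i, j) maximum and return every matching position. This is a sketch, not part of the answer above, and it behaves differently on ties: in the sample a, the row a[0, 1] contains 88 twice, so argwhere reports both [0, 1, 1] and [0, 1, 3] (13 rows instead of 12), while argmax keeps only the first occurrence.

```python
import numpy as np

a = np.array([[[38, 49, 15, 61, 29],
               [31, 88, 45, 88, 20],
               [17, 97, 58, 61, 14],
               [43, 77, 56, 92, 89]],
              [[48, 91, 49, 35, 58],
               [53, 34, 58, 92, 52],
               [20, 35, 70, 41, 81],
               [60, 42, 85, 82, 41]],
              [[45, 41, 32, 41, 25],
               [59, 32, 90, 18, 47],
               [24, 93, 29, 89, 12],
               [80, 27, 12, 51, 33]]])

# keepdims=True keeps the max with shape (3, 4, 1) so it broadcasts against a.
is_max = a == a.max(axis=-1, keepdims=True)

# argwhere returns one [i, j, k] row per True entry, in row-major order,
# so ties along the last axis yield more than one row for the same (i, j).
out_all = np.argwhere(is_max)
print(out_all.shape)  # (13, 3)
```

If you need exactly one row per (i, j), the argmax-based approach is the one to use.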

Alternatively, since the first- and second-axis indices are just regular repetitions, we can build those index arrays with np.repeat and np.tile and then use np.column_stack as before, like so -

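d0 = np.arange(m).repeat(n)    # first-axis index: 0..m-1, each repeated n times
d1 = np.tile(np.arange(n), m)  # second-axis index: 0..n-1, tiled m times
out = np.column_stack((d0, d1, d.ravel()))  # d, m, n as defined above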
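To see why repeat and tile reproduce the mgrid columns, here is a self-contained check using the sample's shape (m, n) = (3, 4). The random test array and its seed are arbitrary choices for this sketch; it only confirms that both ways of building the first two columns give the same result.

```python
import numpy as np

m, n = 3, 4
d0 = np.arange(m).repeat(n)     # [0 0 0 0 1 1 1 1 2 2 2 2]
d1 = np.tile(np.arange(n), m)   # [0 1 2 3 0 1 2 3 0 1 2 3]

# The same columns via np.mgrid, flattened in row-major order.
c, r = np.mgrid[:m, :n]
assert (d0 == c.ravel()).all() and (d1 == r.ravel()).all()

# Full comparison of both approaches on an arbitrary (m, n, 5) integer array.
rng = np.random.default_rng(0)
a = rng.integers(0, 100, size=(m, n, 5))
d = a.argmax(-1)
out_repeat = np.column_stack((d0, d1, d.ravel()))
out_mgrid = np.column_stack((c.ravel(), r.ravel(), d.ravel()))
assert (out_repeat == out_mgrid).all()
```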