lizaveta - 7 days ago

How to find all argmax indices in an ndarray

I have a 2-dimensional NumPy ndarray:

array([[ 0., 20., -2.],
       [ 2.,  1.,  0.],
       [ 4.,  3., 20.]])


How can I get the indices of all the maximum elements? For this input I would like the output array([[0, 1], [2, 2]]).

Answer

Use np.argwhere on a max-equality mask -

np.argwhere(a == a.max())
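As a self-contained sketch, the one-liner can be wrapped in a small helper. The name all_argmax is hypothetical (it is not a NumPy function). Because np.argwhere returns one row per matching element and one column per dimension, the same code should work unchanged for arrays of any dimensionality.

```python
import numpy as np


def all_argmax(a):
    """Return an (N, a.ndim) array with the index of every element equal to a.max().

    Hypothetical helper name; it only wraps np.argwhere on the max-equality mask.
    """
    a = np.asarray(a)
    return np.argwhere(a == a.max())


# The 2-D array from the question
a = np.array([[0., 20., -2.],
              [2., 1., 0.],
              [4., 3., 20.]])
print(all_argmax(a).tolist())  # [[0, 1], [2, 2]]

# On a 3-D array each row of the result is a (depth, row, col) triple
b = np.zeros((2, 2, 2))
b[0, 1, 0] = b[1, 1, 1] = 5
print(all_argmax(b).tolist())  # [[0, 1, 0], [1, 1, 1]]
```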

Sample run -

In [552]: a   # Input array
Out[552]: 
array([[  0.,  20.,  -2.],
       [  2.,   1.,   0.],
       [  4.,   3.,  20.]])

In [553]: a == a.max() # Max equality mask
Out[553]: 
array([[False,  True, False],
       [False, False, False],
       [False, False,  True]], dtype=bool)

In [554]: np.argwhere(a == a.max()) # array of row, col indices of max-mask
Out[554]: 
array([[0, 1],
       [2, 2]])
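If you would rather have the indices grouped per axis (one array of row indices and one of column indices) instead of one (row, col) pair per match, np.nonzero on the same mask returns a tuple that can be used directly for fancy indexing. A minimal sketch of that alternative, reusing the question's array; np.argwhere(mask) is essentially np.transpose(np.nonzero(mask)):

```python
import numpy as np

a = np.array([[0., 20., -2.],
              [2., 1., 0.],
              [4., 3., 20.]])

mask = a == a.max()

# Tuple of per-axis index arrays: (row indices, column indices)
rows, cols = np.nonzero(mask)
print(rows, cols)        # [0 2] [1 2]

# The tuple form can index the array directly
print(a[rows, cols])     # [20. 20.]

# Transposing the tuple gives the same (row, col) pairs as np.argwhere
pairs = np.transpose(np.nonzero(mask))
print(np.array_equal(pairs, np.argwhere(mask)))  # True
```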

If you are working with floating-point numbers, you might want to allow some tolerance in the comparison. For that you could use np.isclose, which applies default relative and absolute tolerances (rtol=1e-05, atol=1e-08). It would replace the earlier a == a.max() part, like so -

In [555]: np.isclose(a, a.max())
Out[555]: 
array([[False,  True, False],
       [False, False, False],
       [False, False,  True]], dtype=bool)
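Feeding the tolerant mask back into np.argwhere completes the floating-point version. The small array below is a made-up example (not from the question) in which rounding makes two values that "should" both be the maximum differ in the last bit, so exact equality misses one of them.

```python
import numpy as np

# 0.1 + 0.2 evaluates to 0.30000000000000004, slightly larger than 0.3
a = np.array([[0.1 + 0.2, 0.3],
              [0.0, 0.1]])

# Exact equality matches only the single largest float
exact = np.argwhere(a == a.max())
print(exact.tolist())     # [[0, 0]]

# np.isclose with its default tolerances also matches the 0.3 entry
tolerant = np.argwhere(np.isclose(a, a.max()))
print(tolerant.tolist())  # [[0, 0], [0, 1]]
```

The tolerances can be tightened or loosened through the rtol and atol arguments of np.isclose if the defaults do not suit your data.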