lizaveta - 1 year ago 119

# How to find all argmax indices in an ndarray

I have a 2-dimensional NumPy ndarray:

```
array([[  0.,  20.,  -2.],
       [  2.,   1.,   0.],
       [  4.,   3.,  20.]])
```

How can I get the indices of all maximum elements? For this array I would like the output `array([[0, 1], [2, 2]])`.

Use `np.argwhere` on a mask of the elements equal to the maximum:

```python
import numpy as np
np.argwhere(a == a.max())  # one [row, col] pair per maximum element
```
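For a quick copy-and-paste check, here is a minimal self-contained version of the one-liner. It rebuilds the array from the question. The `pairs` conversion at the end is an optional extra for readers who want plain Python tuples and is not part of the original answer.

```python
import numpy as np

# The 2-D array from the question
a = np.array([[0., 20., -2.],
              [2., 1., 0.],
              [4., 3., 20.]])

# Boolean mask: True wherever an element equals the global maximum
mask = a == a.max()

# np.argwhere returns one row per True element; for a 2-D input each row is [row, col]
idx = np.argwhere(mask)
print(idx)

# Optional: convert to a list of (row, col) tuples of plain Python ints
pairs = [tuple(row) for row in idx.tolist()]
print(pairs)
```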

Sample run:

```
In [552]: a   # Input array
Out[552]:
array([[  0.,  20.,  -2.],
       [  2.,   1.,   0.],
       [  4.,   3.,  20.]])

In [553]: a == a.max() # Max equality mask
Out[553]:
array([[False,  True, False],
       [False, False, False],
       [False, False,  True]], dtype=bool)

In [554]: np.argwhere(a == a.max()) # array of row, col indices of max-mask
Out[554]:
array([[0, 1],
       [2, 2]])
```
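A mask is needed because `argmax` by itself reports only one position: the first occurrence of the maximum, as a flat index. The minimal sketch below contrasts `argmax` with `np.nonzero` on the same mask. `np.nonzero` returns the same positions as `np.argwhere`, but as one index array per axis, which is the layout that fancy indexing expects. The variable names are illustrative only.

```python
import numpy as np

a = np.array([[0., 20., -2.],
              [2., 1., 0.],
              [4., 3., 20.]])

# argmax gives a single flat index for the FIRST maximum only
first_flat = a.argmax()
first = np.unravel_index(first_flat, a.shape)  # convert flat index to (row, col)

# All maxima: same mask, but grouped per axis (argwhere is the transpose of this)
rows, cols = np.nonzero(a == a.max())

# The per-axis form can index the array directly
values = a[rows, cols]

print(first)
print(rows, cols)
print(values)
```

Use whichever layout suits the next step. `np.argwhere` gives one coordinate pair per row, and `np.nonzero` gives arrays you can pass straight back into `a[...]`.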

If you are working with floating-point numbers, you may want to compare with a tolerance. `np.isclose` applies default relative and absolute tolerances (`rtol=1e-05`, `atol=1e-08`) and can replace the `a == a.max()` mask:

```
In [555]: np.isclose(a, a.max())
Out[555]:
array([[False,  True, False],
       [False, False, False],
       [False, False,  True]], dtype=bool)
```
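The tolerance matters when values that "should" tie differ by rounding error. The small array below is a made-up example, not from the question. It relies on the fact that `0.1 + 0.2` evaluates to `0.30000000000000004` in binary floating point. Exact equality then finds one maximum, while `np.isclose` with its defaults finds both. The `strict` line only demonstrates that `rtol` and `atol` can be passed explicitly. The chosen values are arbitrary.

```python
import numpy as np

# Hypothetical data: 0.1 + 0.2 is slightly larger than 0.3 in floating point
b = np.array([0.3, 0.1 + 0.2, 0.2])

# Exact comparison: only the computed 0.30000000000000004 matches the max
exact = np.argwhere(b == b.max())

# Default tolerances (rtol=1e-05, atol=1e-08) treat 0.3 as equal to the max too
close = np.argwhere(np.isclose(b, b.max()))

# Tolerances can be tightened explicitly; here the rounding gap exceeds atol
strict = np.argwhere(np.isclose(b, b.max(), rtol=0.0, atol=1e-20))

print(exact.ravel())
print(close.ravel())
print(strict.ravel())
```

Pick tolerances based on how the values were computed. The defaults are tuned for values of order 1 and may be too loose or too tight for very large or very small magnitudes.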