A.Yazdiha - 3 months ago
Python Question

What does it mean to sort an array/matrix by the argmax as the key?

I am struggling to understand how sorting works when a NumPy function is used as the key.

import numpy as np

arr = [[8, 5, 9],
       [3, 9.5, 5],
       [5.5, 4, 3.5],
       [6, 2, 1],
       [6, 1, 2],
       [3, 2, 1],
       [8, 5, 3]]
res = sorted(arr, key=np.argmax)


This gives me the following result:

print(res)
[[5.5, 4, 3.5], [6, 2, 1], [6, 1, 2], [3, 2, 1], [8, 5, 3], [3, 9.5, 5], [8, 5, 9]]


I am an R user and not very familiar with Python. I have some idea of the role of the 'key' argument, but I need your help with this specific example.
In a simple case, if the key argument is a function that returns the first element, then sorted sorts the array by each row's first element, but I cannot see how this works with argmax.
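For comparison, here is a minimal sketch of that simple case. The lambda returning row[0] is an illustration added here, not code from the original post. sorted calls the key once per row and compares only the returned numbers, never the rows themselves; rows with the same first element keep their input order.

```python
# Same data as in the question
arr = [[8, 5, 9],
       [3, 9.5, 5],
       [5.5, 4, 3.5],
       [6, 2, 1],
       [6, 1, 2],
       [3, 2, 1],
       [8, 5, 3]]

# Illustrative key: reduce each row to its first element; sorted orders rows by that value
by_first = sorted(arr, key=lambda row: row[0])
print(by_first)
```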
Thanks,

Answer

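The argmax function returns the index of the largest element. It is used as the key in the sort function: sorted calls np.argmax once on each row and then orders the rows by those returned indices, not by the row values themselves.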
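To make the mechanism concrete, here is a sketch of what sorted(arr, key=np.argmax) amounts to, written as the classic "decorate, sort, undecorate" pattern. This is an equivalent illustration, not CPython's actual implementation: each row is paired with its key and its original position, the pairs are sorted, and the rows are read back out.

```python
import numpy as np

arr = [[8, 5, 9],
       [3, 9.5, 5],
       [5.5, 4, 3.5],
       [6, 2, 1],
       [6, 1, 2],
       [3, 2, 1],
       [8, 5, 3]]

# Decorate: one key per row (the column index of the row's maximum)
keys = [int(np.argmax(row)) for row in arr]

# Sort (key, original position, row) tuples. Tuples compare the key first, then the
# position; positions are unique, so the rows themselves are never compared, and
# rows with equal keys stay in their original order.
decorated = sorted(zip(keys, range(len(arr)), arr))

# Undecorate: keep only the rows, now in sorted order
manual = [row for _, _, row in decorated]

print(keys)
print(manual)
```

The result matches sorted(arr, key=np.argmax) row for row.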

If you print this:

print([np.argmax(x) for x in arr])

you get:

[2, 1, 0, 0, 0, 0, 0]

which explains the sorting. The last five rows (key 0) appear first in your result, the first row appears last because it has the highest key (2), and the second row (key 1) appears just before it.
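One way to see these groups directly is to bucket the sorted rows by their key. This sketch uses itertools.groupby from the standard library, chosen here for illustration; groupby only merges adjacent items, which works because res is already sorted by the same key.

```python
from itertools import groupby

import numpy as np

arr = [[8, 5, 9],
       [3, 9.5, 5],
       [5.5, 4, 3.5],
       [6, 2, 1],
       [6, 1, 2],
       [3, 2, 1],
       [8, 5, 3]]

res = sorted(arr, key=np.argmax)

# Map each key (position of the row maximum) to the rows that share it
groups = {int(k): list(rows) for k, rows in groupby(res, key=np.argmax)}

for k, rows in groups.items():
    print(k, rows)
```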

Of course this is a "weak" sorting: the key often returns the same value, so the result also depends on the order of the initial list. Python's sorted is guaranteed to keep rows with equal keys in their original relative order (edit: this property is called stable sorting, see Bakuriu's interesting comment).
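A short sketch of both points: reversing the input reverses the order of the tied rows, and a vectorized NumPy version reproduces the same result only when a stable sort is requested. np.argsort's default kind ("quicksort") is not guaranteed to be stable, so the sketch passes kind="stable" explicitly; the NumPy variant is an alternative added here, not part of the original answer.

```python
import numpy as np

arr = [[8, 5, 9],
       [3, 9.5, 5],
       [5.5, 4, 3.5],
       [6, 2, 1],
       [6, 1, 2],
       [3, 2, 1],
       [8, 5, 3]]

# Stability: tied rows (key 0) come out in whatever order they went in
rev = sorted(arr[::-1], key=np.argmax)

# Vectorized equivalent: one argmax per row, then a stable argsort of those keys
a = np.array(arr)                       # 7x3 float array
row_keys = a.argmax(axis=1)             # column of each row's maximum
order = np.argsort(row_keys, kind="stable")
res_np = a[order]

# Same rows as the list-based version (4 == 4.0, so int/float mixing is fine)
assert res_np.tolist() == sorted(arr, key=np.argmax)
print(rev)
print(res_np)
```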