galah92 - 2 months ago
Python Question

MATLAB sort() vs Numpy argsort() - how to match results?

I'm porting some MATLAB code to Python's NumPy.

In MATLAB (Octave, actually), I have something like:

>> someArr = [9, 8, 7, 7]
>> [~, ans] = sort(someArr, 'descend')
ans =
1 2 3 4


So in Numpy I'm doing:

>>> someArr = np.array([9, 8, 7, 7])
>>> np.argsort(someArr)[::-1]
array([0, 1, 3, 2])


I got 1, 2, 3, 4 in MATLAB while I got 0, 1, 3, 2 in NumPy, and I need 0, 1, 2, 3 in NumPy.

I believe it's due to the sorting algorithm used in each function, but I checked and it looks like both are using "quicksort" (see here and here).

How can I make NumPy's result match the MATLAB one?

Answer

In order to make this work, we need to be a little bit clever. numpy doesn't have an analog of 'descend'. You're mimicking it by reversing the results of the sort (which is ultimately your undoing).

I'm not sure how matlab accomplishes it, but they claim to use a stable variant of quicksort. Specifically, for descending sorts:

If the flag is 'descend', the output is reversed just before being returned. After the reversal, we ran the index vector sorts to restore stability.

It appears that octave follows suit here.

Since their sort is stable, you have the guarantee that the order of equal values in the input will be preserved in the output. numpy, on the other hand, makes no such guarantee for its quicksort. If we want a stable sort in numpy, we need to use mergesort:

>>> np.argsort(someArr, kind='mergesort')
array([2, 3, 1, 0])

Ok, this output makes sense. someArr[2] == someArr[3] and the third element comes before the fourth, so it is sensible that 2 would be before 3 in the output (without a guaranteed stable sorting algorithm, we couldn't make this claim). This also shows why reversing fails: flipping this result puts 3 ahead of 2, so the tied elements come out in the wrong order. Now comes the clever step... You want the "descending" values, but rather than reversing the output of argsort, why don't we negate the input? That will have the effect of larger numbers sorting before their smaller counterparts, just as a descending sort would do...

>>> np.argsort(-someArr, kind='mergesort')
array([0, 1, 2, 3])

Now we're talkin! And since mergesort is guaranteed to be stable, elements (with equal values) which appear at lower indices will appear in the output first -- Just like matlab/octave. Nice.
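One caveat with the negation trick: it only works for values that can be negated safely. For unsigned integers, for example, zero stays zero while every other value wraps around, and strings can't be negated at all. Here is a rough sketch of an alternative that avoids negation by stably sorting the reversed array and mapping the positions back. The helper name argsort_descending is made up for illustration (it isn't a numpy function), and it assumes a 1-D input:

```python
import numpy as np

def argsort_descending(arr):
    """Stable descending argsort for 1-D input (hypothetical helper).

    Ties keep their original left-to-right order, which is the
    behaviour described above for MATLAB/Octave's 'descend' flag.
    """
    arr = np.asarray(arr)
    # Stable ascending sort of the reversed array: among equal values,
    # elements that come later in the original array are listed first.
    rev_order = np.argsort(arr[::-1], kind='mergesort')
    # Convert positions in the reversed array back to original indices,
    # then flip so that the largest values come first.
    return (arr.size - 1 - rev_order)[::-1]

some_arr = np.array([9, 8, 7, 7])

# Reversing a stable ascending argsort also reverses the ties.
reversed_ascending = np.argsort(some_arr, kind='mergesort')[::-1]

zero_based = argsort_descending(some_arr)
one_based = zero_based + 1  # MATLAB/Octave indices start at 1

print(reversed_ascending)  # [0 1 3 2]
print(zero_based)          # [0 1 2 3]
print(one_based)           # [1 2 3 4]
```

For plain signed integers or floats, the one-liner with the negated input is simpler and gives the same result. The sketch is only worth the extra lines if the dtype makes negation unsafe.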
