galah92 - 2 months ago 35
Python Question

# MATLAB sort() vs Numpy argsort() - how to match results?

I'm porting some MATLAB code to Python's Numpy.

In MATLAB (Octave, actually), I have something like:

``````
>> someArr = [9, 8, 7, 7]
>> [~, ans] = sort(someArr, 'descend')
ans =
1   2   3   4
``````

So in Numpy I'm doing:

``````
>>> import numpy as np
>>> someArr = np.array([9, 8, 7, 7])
>>> np.argsort(someArr)[::-1]
array([0, 1, 3, 2])
``````

I got `1, 2, 3, 4` in MATLAB but `0, 1, 3, 2` in Numpy, and I need `0, 1, 2, 3` in Numpy.

I believe it's due to the sorting algorithm used in each function, but I checked and it looks like both are using "quicksort" (see here and here).

How can I make Numpy's result match the MATLAB one?

In order to make this work, we need to be a little bit clever. `numpy` doesn't have an analog of the `'descend'` flag. You're mimicking it by reversing the results of the sort, which is ultimately your undoing: the reversal also reverses the order of equal elements.

I'm not sure how `matlab` accomplishes it, but they claim to use a stable variant of quicksort. Specifically, for descending sorts:

> If the flag is 'descend', the output is reversed just before being returned. After the reversal, we ran the index vector sorts to restore stability.

It appears that `octave` follows suit here.
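To make that description concrete, here is a rough NumPy sketch of the "reverse, then repair the ties" idea. It is only my reading of the quoted sentence, not MATLAB's or Octave's actual implementation, and the function name `descend_by_reversal` is made up for illustration:

```python
import numpy as np

def descend_by_reversal(values):
    """Descending argsort: reverse an ascending argsort, then repair tie order.

    Hypothetical helper, written only to illustrate the quoted description.
    """
    values = np.asarray(values)
    # Ascending argsort, reversed so the largest values come first.
    order = np.argsort(values)[::-1].copy()
    sorted_vals = values[order]
    # A new run of equal values starts wherever the sorted value changes.
    starts = np.flatnonzero(np.r_[True, sorted_vals[1:] != sorted_vals[:-1]])
    ends = np.r_[starts[1:], len(order)]
    for lo, hi in zip(starts, ends):
        # Sorting the indices inside a run restores the original input order.
        order[lo:hi].sort()
    return order

print(descend_by_reversal([9, 8, 7, 7]))  # [0 1 2 3]
```

Because the ties are repaired afterwards, the initial ascending sort in this sketch does not itself need to be stable.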

Since MATLAB's and Octave's sorts are stable, you have the guarantee that the order of equal values in the input will be preserved in the output. `numpy`, on the other hand, makes no such guarantee for its quicksort. If we want a stable sort in numpy, we need to use `mergesort`:

``````
>>> np.argsort(someArr, kind='mergesort')
array([2, 3, 1, 0])
``````

OK, this output makes sense. `someArr[2] == someArr[3]` and the third element comes before the fourth, so it is sensible that `2` comes before `3` in the output (without a guaranteed stable sorting algorithm, we couldn't make this claim). Now comes the clever step. You want the values in descending order, but rather than reversing the output of `argsort`, why don't we negate the input? That makes larger numbers sort before their smaller counterparts, just as a descending sort would:

``````
>>> np.argsort(-someArr, kind='mergesort')
array([0, 1, 2, 3])
``````

Now we're talkin! And since mergesort is guaranteed to be stable, elements with equal values that appear at lower indices will appear in the output first, just like matlab/octave (apart from MATLAB's 1-based indexing). Nice.
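If you need this in more than one place, it could be wrapped in a small helper. Treat this as a sketch, not a drop-in replacement for MATLAB's `sort`: the name `argsort_descend` and the `one_based` flag are hypothetical. Instead of negating the input, it reverses the array before and after a stable sort. That should give the same tie order, and it also covers unsigned integer arrays, where negation wraps around (for `uint8`, `-0` stays `0` while `-5` becomes `251`):

```python
import numpy as np

def argsort_descend(values, one_based=False):
    """Stable descending argsort; ties keep their original input order.

    Hypothetical helper. Set one_based=True to compare against MATLAB indices.
    """
    values = np.asarray(values)
    n = values.shape[0]
    # Stable ascending argsort of the reversed input, reversed again:
    # descending order, with ties ordered by original position.
    rev_order = np.argsort(values[::-1], kind='mergesort')[::-1]
    # Map positions in the reversed array back to original indices.
    order = n - 1 - rev_order
    return order + 1 if one_based else order

some_arr = np.array([9, 8, 7, 7])
print(argsort_descend(some_arr))                  # [0 1 2 3]
print(argsort_descend(some_arr, one_based=True))  # [1 2 3 4]
```

For signed integers and floats this returns the same indices as `np.argsort(-someArr, kind='mergesort')` on the example above.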