Raphadasilva - 1 month ago
Python Question

idxmax() equality with pandas

I'm working on a CSV file full of electoral data. A sample of my raw data looks like this:

    city  party1  party2  party3
0  city1      50     107     114
1  city2     181     323     326
2  city3      26      28      75
3  city4      32      47      59
4   ciy5       8      21      21


I used the pandas idxmax() function to create a new column called "winner", like this:

mydf['winner'] = mydf[['party1','party2','party3']].idxmax(axis=1)


My goal was to determine which party came first in each city. Here is the result:

    city  party1  party2  party3  winner
0  city1      50     107     114  party3
1  city2     181     323     326  party3
2  city3      26      28      75  party3
3  city4      32      47      59  party3
4   ciy5       8      21      21  party2


The winner value in the last row is wrong, because party2 and party3 have the same score.

Is it possible to make idxmax handle the case where two values are tied for the maximum?
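
For anyone who wants to reproduce the problem, here is a minimal sketch that rebuilds the sample above by hand (the DataFrame constructor call stands in for the CSV load, which is not shown in the question). It illustrates that idxmax returns the label of the first column holding the row maximum, so a tie is resolved silently in favour of the leftmost column.

```python
import pandas as pd

# Sample data copied from the question (the "ciy5" spelling is kept as posted).
mydf = pd.DataFrame({
    "city": ["city1", "city2", "city3", "city4", "ciy5"],
    "party1": [50, 181, 26, 32, 8],
    "party2": [107, 323, 28, 47, 21],
    "party3": [114, 326, 75, 59, 21],
})

scores = mydf[["party1", "party2", "party3"]]

# idxmax reports the first column that reaches the row maximum,
# so the 21/21 tie in the last row comes back as party2.
winner = scores.idxmax(axis=1)
print(winner.tolist())
```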

Answer

You can use DataFrame.eq to compare the subset of vote columns with its row-wise DataFrame.max. Summing the resulting booleans per row gives the number of columns that reach the maximum, so a value greater than 1 means the maximum is shared. You can then overwrite the idxmax result with Series.mask wherever s > 1:

# Vote columns only
a = mydf[['party1','party2','party3']]
mydf['winner'] = a.idxmax(axis=1)

# Number of columns per row that equal the row maximum
s = a.eq(a.max(axis=1), axis=0).sum(axis=1)
print(s)
0    1
1    1
2    1
3    1
4    2
dtype: int64

# Replace the winner with a marker wherever the maximum is shared
mydf['winner'] = mydf['winner'].mask(s > 1, 'Equality')
print(mydf)
    city  party1  party2  party3    winner
0  city1      50     107     114    party3
1  city2     181     323     326    party3
2  city3      26      28      75    party3
3  city4      32      47      59    party3
4   ciy5       8      21      21  Equality
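
The steps above can be wrapped into one self-contained, runnable sketch. The helper name label_winner and its tie_label parameter are made up for illustration and are not part of pandas; the sample frame is rebuilt by hand from the question.

```python
import pandas as pd

def label_winner(df, party_cols, tie_label="Equality"):
    """Return the winning column per row, or tie_label when the maximum is shared.

    label_winner and tie_label are hypothetical names used for this sketch.
    """
    scores = df[party_cols]
    # True where a cell equals the maximum of its row.
    is_max = scores.eq(scores.max(axis=1), axis=0)
    # More than one True in a row means the maximum is shared.
    tied = is_max.sum(axis=1) > 1
    # Keep the idxmax label, except where the row is tied.
    return scores.idxmax(axis=1).mask(tied, tie_label)

mydf = pd.DataFrame({
    "city": ["city1", "city2", "city3", "city4", "ciy5"],
    "party1": [50, 181, 26, 32, 8],
    "party2": [107, 323, 28, 47, 21],
    "party3": [114, 326, 75, 59, 21],
})

mydf["winner"] = label_winner(mydf, ["party1", "party2", "party3"])
print(mydf)
```

Passing the column list as an argument keeps the helper usable when the file has more than three parties.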

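If you also need the names of all the tied parties, multiply the boolean DataFrame by its column names with mul (each True becomes the column name, each False an empty string), join each row with apply, and finally remove the leading and trailing commas with strip: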

a = mydf[['party1','party2','party3']]
df = a.eq(a.max(axis=1), axis=0)
print (df)
  party1 party2 party3
0  False  False   True
1  False  False   True
2  False  False   True
3  False  False   True
4  False   True   True

# Parentheses are required so the chained calls can span several lines
mydf['winner'] = (df.mul(df.columns.to_series())
                    .apply(','.join, axis=1)
                    .str.strip(','))
print (mydf)
    city  party1  party2  party3         winner
0  city1      50     107     114         party3
1  city2     181     323     326         party3
2  city3      26      28      75         party3
3  city4      32      47      59         party3
4   ciy5       8      21      21  party2,party3
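
One caveat with stripping commas: strip(',') only trims the ends of the string, so a tie between non-adjacent columns (party1 and party3, say) would leave a doubled comma in the middle. A possible alternative, sketched below, builds the string from the flagged column labels directly, which also avoids multiplying booleans by strings. The helper name tied_winners is hypothetical, and this is a swapped-in technique rather than the method shown above.

```python
import pandas as pd

def tied_winners(scores, sep=","):
    """Join the labels of every column that reaches the row maximum.

    tied_winners is a hypothetical helper name used for this sketch.
    """
    # True where a cell equals the maximum of its row.
    is_max = scores.eq(scores.max(axis=1), axis=0)
    cols = is_max.columns
    # Select the flagged column labels row by row and join them.
    joined = [sep.join(cols[row]) for row in is_max.to_numpy()]
    return pd.Series(joined, index=scores.index)

mydf = pd.DataFrame({
    "city": ["city1", "city2", "city3", "city4", "ciy5"],
    "party1": [50, 181, 26, 32, 8],
    "party2": [107, 323, 28, 47, 21],
    "party3": [114, 326, 75, 59, 21],
})

mydf["winner"] = tied_winners(mydf[["party1", "party2", "party3"]])
print(mydf)

# Made-up extra row with a tie between non-adjacent columns.
extra = pd.DataFrame({"party1": [5], "party2": [1], "party3": [5]})
print(tied_winners(extra).tolist())
```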