Amani Amani - 7 months ago 11
Python Question

Pandas, Avoiding hierarchy in pivot table

I have a pandas DataFrame, df, from which a pivot table is generated using the following function:

import pandas as pd

def objective2(excel_file):
    df = pd.read_excel(excel_file)

    # WBC cut-offs
    df['WBC_groups'] = pd.cut(df.WBC, [0, 4, 12, 100],
                              labels=['WBC < 4', 'WBC Normal', 'WBC > 12'])

    # Helper column: summing it counts the rows in each cell
    df['count'] = 1

    table = df.pivot_table('count', index=['Sex'],
                           columns=['WBC_groups', 'Outcome_at_24'],
                           aggfunc='sum',
                           margins=True, margins_name='Total')

    return table


This generates the following table:

WBC_groups    WBC < 4      WBC Normal      WBC > 12       Total
Outcome_at_24   Alive Died      Alive Died    Alive Died       
Sex                                                            
Female           10.0  2.0       20.0  6.0     14.0  NaN   86.0
Male              3.0  NaN       28.0  3.0     26.0  4.0  111.0
Total            13.0  2.0       48.0  9.0     40.0  4.0  197.0


How can I avoid the hierarchy in the columns so that the table looks like this:

WBC_groups  WBC < 4  WBC Normal  WBC > 12  Alive  Died  Total
Sex
Female         10.0         2.0      20.0    6.0  14.0   86.0
Male            3.0         NaN      28.0    3.0  26.0  111.0
Total          13.0         2.0      48.0    9.0  40.0  197.0


Note: the data in the tables are not accurate; they are just dummy values.

Answer

I think you cannot avoid the hierarchy, because pivot_table is called with two columns in the columns parameter: WBC_groups and Outcome_at_24.
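To see why the hierarchy appears, here is a minimal sketch on made-up toy data (the column names mirror the question; the values are hypothetical). It compares the column index produced by one key in columns with the one produced by two keys:

```python
import pandas as pd

# Hypothetical toy data; names follow the question, values are invented.
df = pd.DataFrame({
    'Sex': ['Female', 'Female', 'Male', 'Male'],
    'WBC_groups': ['WBC < 4', 'WBC Normal', 'WBC < 4', 'WBC Normal'],
    'Outcome_at_24': ['Alive', 'Died', 'Alive', 'Alive'],
    'count': 1,
})

# One key in `columns` gives a flat, single-level column index.
flat = df.pivot_table('count', index='Sex',
                      columns='WBC_groups', aggfunc='sum')

# Two keys in `columns` give a two-level MultiIndex, one level per key.
nested = df.pivot_table('count', index='Sex',
                        columns=['WBC_groups', 'Outcome_at_24'],
                        aggfunc='sum')

print(flat.columns.nlevels)
print(nested.columns.nlevels)
print(list(nested.columns.names))
```

Each key passed to columns becomes one level of the resulting column index, so any flattening has to happen after the pivot.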

The easiest solution is to set new column names and then drop the column rem:

# df here is the pivot table returned by objective2, not the raw data
df.columns = ['WBC < 4', 'WBC Normal', 'WBC > 12', 'Alive', 'Died', 'rem', 'Total']
df = df.drop('rem', axis=1)
print(df)
        WBC < 4  WBC Normal  WBC > 12  Alive  Died  Total
Sex                                                      
Female     10.0         2.0      20.0    6.0  14.0   86.0
Male        3.0         NaN      28.0    3.0  26.0  111.0
Total      13.0         2.0      48.0    9.0  40.0  197.0

But if you need a more general solution:

print(df)
WBC_groups    WBC < 4      WBC Normal      WBC > 12       Total
Outcome_at_24   Alive Died      Alive Died    Alive Died       
Sex                                                            
Female           10.0  2.0       20.0  6.0     14.0  NaN   86.0
Male              3.0  NaN       28.0  3.0     26.0  4.0  111.0
Total            13.0  2.0       48.0  9.0     40.0  4.0  197.0

# Unique labels of the first column level, in order of appearance
cols1 = df.columns.get_level_values('WBC_groups').to_series().drop_duplicates().tolist()
print(cols1)
['WBC < 4', 'WBC Normal', 'WBC > 12', 'Total']

# Unique labels of the second column level, in order of appearance
cols2 = df.columns.get_level_values('Outcome_at_24').to_series().drop_duplicates().tolist()
print(cols2)
['Alive', 'Died', ' ']

# Build the flat names; 'rem' is a placeholder for the column to be dropped
cols = cols1[:-1] + cols2[:2] + ['rem'] + cols1[-1:]
print(cols)
['WBC < 4', 'WBC Normal', 'WBC > 12', 'Alive', 'Died', 'rem', 'Total']

df.columns = cols

df = df.drop('rem', axis=1)
print(df)
        WBC < 4  WBC Normal  WBC > 12  Alive  Died  Total
Sex                                                      
Female     10.0         2.0      20.0    6.0  14.0   86.0
Male        3.0         NaN      28.0    3.0  26.0  111.0
Total      13.0         2.0      48.0    9.0  40.0  197.0
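
A different way to get a single header row, not the method used above, is to keep every column and join the two level labels into one string. This is only a sketch on hypothetical toy data (the names mirror the question, the values are invented), and it yields combined names such as 'WBC < 4 Alive' rather than the exact layout asked for:

```python
import pandas as pd

# Hypothetical toy data; names follow the question, values are invented.
df = pd.DataFrame({
    'Sex': ['Female', 'Female', 'Male', 'Male'],
    'WBC_groups': ['WBC < 4', 'WBC Normal', 'WBC < 4', 'WBC Normal'],
    'Outcome_at_24': ['Alive', 'Died', 'Alive', 'Alive'],
    'count': 1,
})

table = df.pivot_table('count', index='Sex',
                       columns=['WBC_groups', 'Outcome_at_24'],
                       aggfunc='sum',
                       margins=True, margins_name='Total')

# Each column label is a (WBC_groups, Outcome_at_24) tuple. Join the two
# parts with a space; strip() removes the trailing space left by the
# empty second level of the 'Total' column.
table.columns = [' '.join(col).strip() for col in table.columns]

print(table.columns.tolist())
print(table)
```

Because no column is dropped or renamed by position, this keeps all the counts and does not depend on how many groups the pivot happens to produce.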