kjo - 6 months ago
Python Question

Aggregating groups into row vectors (rather than scalars)

I want to apply a function to every group in a groupby object, so that the function operates on multiple columns of each group and returns a 1 x n "row vector" as its result. I want the n entries of these row vectors to form the contents of n new columns in the resulting DataFrame.

Here's an example.

import pandas as pd
import numpy as np

df = pd.DataFrame.from_records([(0, 0, 0.616, 0.559),
                                (0, 0, 0.976, 0.942),
                                (0, 0, 0.363, 0.223),
                                (0, 0, 0.033, 0.225),
                                (0, 0, 0.950, 0.351),
                                (0, 1, 0.272, 0.004),
                                (0, 1, 0.167, 0.177),
                                (0, 1, 0.520, 0.157),
                                (0, 1, 0.435, 0.547),
                                (0, 1, 0.266, 0.850),
                                (1, 0, 0.368, 0.544),
                                (1, 0, 0.067, 0.064),
                                (1, 0, 0.566, 0.533),
                                (1, 0, 0.102, 0.431),
                                (1, 0, 0.240, 0.997),
                                (1, 1, 0.867, 0.793),
                                (1, 1, 0.519, 0.477),
                                (1, 1, 0.110, 0.853),
                                (1, 1, 0.160, 0.155),
                                (1, 1, 0.735, 0.515)],
                               columns=list('vwxy'))

grouped = df.groupby(list('vw'))

def example(group):
    X2 = np.var(group['x'])
    Y2 = np.var(group['y'])
    X = np.sqrt(X2)
    Y = np.sqrt(Y2)
    R2 = X2 + Y2
    M = 1.0 / (R2 + 1)
    return (M * 2 * X, M * 2 * Y, M * (R2 - 1))


This gets close:

grouped.apply(example).reset_index()

#    v  w                                                  0
# 0  0  0  (0.596122357697, 0.450073544336, -0.664884906839)
# 1  0  1  (0.229241003533, 0.555057863705, -0.799599481139)
# 2  1  0   (0.326212671335, 0.53100544639, -0.782060425392)
# 3  1  1  (0.523276087715, 0.433768876798, -0.733503031723)


...but what I'm after is this:

#    v  w         a         b         c
# 0  0  0  0.596122  0.450074 -0.664885
# 1  0  1  0.229241  0.555058 -0.799599
# 2  1  0  0.326213  0.531005 -0.782060
# 3  1  1  0.523276  0.433769 -0.733503


How can I achieve this?

It's OK to modify the example function, as long as it continues to return all 3 values in some form. IOW, I don't want a solution based on replacing example with 3 separate functions, one for each of the output columns.

Answer

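Try returning a pandas Series instead of a tuple from example. When the applied function returns a Series for every group, groupby.apply combines them into a DataFrame with one row per group and the Series index as the column labels, so indexing the Series with a, b, c gives you those three columns directly (reset_index() then turns v and w back into ordinary columns):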

def example(group):
    # .... (compute X, Y, R2 and M as in the question)
    return pd.Series([M * 2 * X, M * 2 * Y, M * (R2 - 1)], index=list('abc'))
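For reference, here is a complete, self-contained sketch of the Series-returning approach, with the body of example filled in from the question. Two small choices are my own, not part of the answer above: group['x'].var(ddof=0) is used as an explicit spelling of the population variance that np.var computes, and the groupby selects only [['x', 'y']] before apply, so the grouping columns v and w are not passed into example (recent pandas versions, 2.2 and later as far as I know, warn when apply operates on the grouping columns).

```python
import numpy as np
import pandas as pd

# Same data as in the question.
records = [(0, 0, 0.616, 0.559), (0, 0, 0.976, 0.942), (0, 0, 0.363, 0.223),
           (0, 0, 0.033, 0.225), (0, 0, 0.950, 0.351), (0, 1, 0.272, 0.004),
           (0, 1, 0.167, 0.177), (0, 1, 0.520, 0.157), (0, 1, 0.435, 0.547),
           (0, 1, 0.266, 0.850), (1, 0, 0.368, 0.544), (1, 0, 0.067, 0.064),
           (1, 0, 0.566, 0.533), (1, 0, 0.102, 0.431), (1, 0, 0.240, 0.997),
           (1, 1, 0.867, 0.793), (1, 1, 0.519, 0.477), (1, 1, 0.110, 0.853),
           (1, 1, 0.160, 0.155), (1, 1, 0.735, 0.515)]
df = pd.DataFrame.from_records(records, columns=list('vwxy'))


def example(group):
    # Population variance (ddof=0), i.e. the same value np.var returns.
    X2 = group['x'].var(ddof=0)
    Y2 = group['y'].var(ddof=0)
    X = np.sqrt(X2)
    Y = np.sqrt(Y2)
    R2 = X2 + Y2
    M = 1.0 / (R2 + 1)
    # A Series (not a tuple): apply() turns its index into column labels.
    return pd.Series([M * 2 * X, M * 2 * Y, M * (R2 - 1)], index=list('abc'))


# Only x and y are handed to example; v and w stay as the group keys.
result = df.groupby(list('vw'))[['x', 'y']].apply(example).reset_index()
print(result)
```

The printed frame should have the v, w, a, b, c layout shown in the question, one row per (v, w) group.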