Philipp_Kats - 7 months ago
Python Question

How do I flatten a PySpark dataframe by one array column?

I have a spark dataframe like this:

+------+--------+--------------+--------------------+
|   dbn|    boro|total_students|                sBus|
+------+--------+--------------+--------------------+
|17K548|Brooklyn|           399|[B41, B43, B44-SB...|
|09X543|   Bronx|           378|[Bx13, Bx15, Bx17...|
|09X327|   Bronx|           543|[Bx1, Bx11, Bx13,...|
+------+--------+--------------+--------------------+


How do I flatten it so that each row is copied once for each element in sBus, and sBus becomes a normal string column?

The result should look like this:

+------+--------+--------------+--------------------+
|   dbn|    boro|total_students|                sBus|
+------+--------+--------------+--------------------+
|17K548|Brooklyn|           399|                 B41|
|17K548|Brooklyn|           399|                 B43|
|17K548|Brooklyn|           399|              B44-SB|
+------+--------+--------------+--------------------+


and so on...

Answer

I can't think of a way to do this without turning it into an RDD. Here's the full code:

from pyspark import SparkContext
from pyspark.sql import SQLContext
sc = SparkContext()
sqlContext = SQLContext(sc)
sc.setLogLevel("FATAL")

df = sqlContext.createDataFrame([
    {'dbn': '17K548', 'boro': 'Brooklyn', 'total_students': 399, 
        'sBus': ['B41', 'B43', 'B44-SB']}, 
    {'dbn': '09X543', 'boro': 'Bronx', 'total_students': 378, 
        'sBus': ['Bx13', 'Bx15', 'Bx17']}, 
    {'dbn': '09X327', 'boro': 'Bronx', 'total_students': 543, 
        'sBus': ['Bx1', 'Bx11', 'Bx13']}
        ])

# convert to rdd
rdd = df.rdd

def extract(row, key):
    """Takes a Row and a key, returns tuple of (dict of the other columns, row[key])."""
    _dict = row.asDict()
    _list = _dict[key]
    del _dict[key]
    return (_dict, _list)


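def add_to_dict(_dict, key, value):
    """Returns a copy of _dict with key set to value."""
    # copy first: rows flattened from the same source row share one dict object
    new_dict = dict(_dict)
    new_dict[key] = value
    return new_dict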


# preserve rest of values in key, put list to flatten in value
rdd = rdd.map(lambda x: extract(x, 'sBus'))
# make a row for each item in value
rdd = rdd.flatMapValues(lambda x: x)
# add flattened value back into dictionary
rdd = rdd.map(lambda x: add_to_dict(x[0], 'sBus', x[1]))
# convert back to dataframe
df = sqlContext.createDataFrame(rdd)

df.show()

The tricky part is keeping the other columns together with the newly flattened values. I do this by mapping each row to a tuple of (dict of other columns, list to flatten) and then calling flatMapValues. This will split each element of the value list into a separate row, but keep the keys attached, i.e.

(key, ['A', 'B', 'C'])

becomes

(key, 'A')
(key, 'B')
(key, 'C')
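
To make that behaviour concrete, here is a rough pure-Python stand-in for what flatMapValues does to a collection of pairs. The helper name flat_map_values is made up for illustration and is not part of PySpark; it only mimics the semantics on an in-memory list, so no Spark installation is assumed.

```python
def flat_map_values(pairs, f):
    """Mimic RDD.flatMapValues on a plain list of (key, value) pairs."""
    # Apply f to each value and emit one (key, item) pair per resulting item,
    # leaving the key untouched.
    return [(key, item) for key, value in pairs for item in f(value)]


pairs = [("k1", ["A", "B", "C"]), ("k2", ["D"])]
flattened = flat_map_values(pairs, lambda values: values)
print(flattened)
```

Passing the identity function, as the answer does with `lambda x: x`, simply unrolls each list while repeating its key.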

Then I move the flattened value back into the dictionary of other columns and convert the result back to a DataFrame.
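
The three steps (split off the list, flatten it, merge each item back) can also be sketched on plain dictionaries, which may help when checking the logic without a Spark session. This is only an illustration under that assumption: flatten_rows is a hypothetical helper, not a PySpark function, and the sample rows are shortened versions of the ones in the question.

```python
def flatten_rows(rows, key):
    """Return one dict per element of row[key], copying the other fields."""
    result = []
    for row in rows:
        # Step 1: separate the other columns from the list to flatten.
        rest = {k: v for k, v in row.items() if k != key}
        # Step 2: emit one output row per list element.
        for item in row[key]:
            # Step 3: put the single value back under the original key,
            # on a fresh copy so output rows do not share a dictionary.
            new_row = dict(rest)
            new_row[key] = item
            result.append(new_row)
    return result


rows = [
    {"dbn": "17K548", "boro": "Brooklyn", "total_students": 399,
     "sBus": ["B41", "B43", "B44-SB"]},
    {"dbn": "09X543", "boro": "Bronx", "total_students": 378,
     "sBus": ["Bx13"]},
]
flat = flatten_rows(rows, "sBus")
for row in flat:
    print(row)
```

The input rows are left unmodified, and each output row carries a string in sBus instead of a list.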

I would've done this entirely from within a DataFrame, but DataFrames can't have key/value tuples and thus they don't have a flatMapValues method.