Brian Bruggeman - 5 months ago

Pyspark: How to transform json strings in a dataframe column

The following is more or less plain Python code that extracts exactly what I want. The dataframe column I'm pulling out holds what is basically a JSON string: a list of objects.

However, I had to greatly bump up the memory requirement for this and I'm only running on a single node. Using a collect is probably bad and creating all of this on a single node really isn't taking advantage of the distributed nature of Spark.

I'd like a more Spark-centric solution. Can anyone help me massage the logic below to take better advantage of Spark? Also, as a learning point: please explain why and how the changes make it better.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json

from pyspark.sql.types import StructType, StructField, StringType

input_schema = StructType([
    StructField('scrubbed_col_name', StringType(), nullable=True),
])

output_schema = StructType([
    StructField('val01_field_name', StringType(), nullable=True),
    StructField('val02_field_name', StringType(), nullable=True),
])

example_input = [
    '''[{"val01_field_name": "val01_a", "val02_field_name": "val02_a"},
        {"val01_field_name": "val01_a", "val02_field_name": "val02_b"},
        {"val01_field_name": "val01_b", "val02_field_name": "val02_c"}]''',
    '''[{"val01_field_name": "val01_c", "val02_field_name": "val02_a"}]''',
    '''[{"val01_field_name": "val01_a", "val02_field_name": "val02_d"}]''',
]

desired_output = {
    'val01_a': ['val02_a', 'val02_b', 'val02_d'],
    'val01_b': ['val02_c'],
    'val01_c': ['val02_a'],
}


def capture(dataframe):
    # Capture the column from the data frame where it is not null.
    # SQL needs IS NOT NULL here: a `!= null` comparison never evaluates to true.
    data = dataframe.filter('scrubbed_col_name IS NOT NULL')\
                    .select('scrubbed_col_name')\
                    .rdd\
                    .collect()

    # Create a mapping of val1: list(val2)
    mapping = {}
    # For every row in the collected data
    for row in data:
        # For each json_string within the row
        for json_string in row:
            # For each item within the json string
            for val in json.loads(json_string):
                # Extract both fields; .get returns None if a key is missing
                val01 = val.get('val01_field_name')
                val02 = val.get('val02_field_name')
                # Only record each val02 once per val01
                if val02 not in mapping.get(val01, []):
                    mapping.setdefault(val01, []).append(val02)
    return mapping


One possible solution:

(df  # The dataframe holding the JSON string column
  .rdd  # Convert to rdd
  .flatMap(lambda x: x)  # Flatten rows
  # Parse JSON. In practice you should add proper exception handling
  .flatMap(lambda x: json.loads(x))
  # Get values
  .map(lambda x: (x.get('val01_field_name'), x.get('val02_field_name')))
  # Convert to final shape
  .groupByKey())
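
The "proper exception handling" mentioned in the comments could look roughly like the sketch below. The helper name `parse_pairs` is hypothetical (it is not part of the original answer), and silently dropping malformed or null strings is an assumption; you may prefer to log or count them instead.

```python
import json


def parse_pairs(json_string):
    """Return (val01, val02) tuples from one JSON string, or [] if it cannot be parsed."""
    try:
        items = json.loads(json_string)
    except (TypeError, ValueError):
        # TypeError covers None; ValueError covers malformed JSON
        # (json.JSONDecodeError is a subclass of ValueError).
        return []
    if not isinstance(items, list):
        # The column is expected to hold a JSON list of objects.
        return []
    return [
        (item.get('val01_field_name'), item.get('val02_field_name'))
        for item in items
        if isinstance(item, dict)
    ]
```

With a helper like this, the JSON-parsing `.flatMap` and the `.map` steps could collapse into a single `.flatMap(parse_pairs)`, and one bad record would no longer fail the whole job.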

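Given the output specification, this operation is not exactly efficient (do you really require grouped values?), but it is still much better than collect: the parsing and grouping run on the executors instead of pulling every raw row back to the driver.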
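
If grouped, de-duplicated values really are required, one alternative to `groupByKey` is `aggregateByKey` with a set per key, which merges duplicates within each partition before the shuffle. This is only a sketch: the names `add_value`, `merge_sets` and `build_mapping` are hypothetical, and `pairs_rdd` is assumed to be the RDD of `(val01, val02)` tuples produced by the `.map` step above.

```python
def add_value(seen, value):
    """Fold one val02 into the set accumulated for a key within a partition."""
    seen.add(value)
    return seen


def merge_sets(left, right):
    """Combine the sets built for the same key on different partitions."""
    left.update(right)
    return left


def build_mapping(pairs_rdd):
    """Return {val01: [val02, ...]} with duplicates removed.

    pairs_rdd is assumed to be a PySpark RDD of (val01, val02) tuples.
    """
    return (pairs_rdd
            .aggregateByKey(set(), add_value, merge_sets)
            .mapValues(list)   # Sets to lists; element order is not guaranteed
            .collectAsMap())   # Only the small, grouped result reaches the driver
```

Note that `collectAsMap` still brings data to the driver, but only the de-duplicated mapping rather than every raw JSON string, which is where the memory savings come from.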