kellanburket - 3 months ago
Python Question

How to guarantee repartitioning in a Spark DataFrame

I'm pretty new to Apache Spark and I'm trying to repartition a DataFrame by U.S. state. I then want to break each partition out into its own RDD and save it to a specific location:

import json
from pyspark.sql import types

# Assumes spark_context (SparkContext) and sql_context (SQLContext) already exist.
schema = types.StructType([
    types.StructField("details", types.StructType([
        types.StructField("state", types.StringType(), True)
    ]), True)
])

raw_rdd = spark_context.parallelize([
    '{"details": {"state": "AL"}}',
    '{"details": {"state": "AK"}}',
    '{"details": {"state": "AZ"}}',
    '{"details": {"state": "AR"}}',
    '{"details": {"state": "CA"}}',
    '{"details": {"state": "CO"}}',
    '{"details": {"state": "CT"}}',
    '{"details": {"state": "DE"}}',
    '{"details": {"state": "FL"}}',
    '{"details": {"state": "GA"}}'
]).map(
    lambda row: json.loads(row)
)

# Ask for 10 partitions keyed on the nested state field.
rdd = sql_context.createDataFrame(raw_rdd).repartition(10, "details.state").rdd

for index in range(0, rdd.getNumPartitions()):
    # Keep only the rows that live in partition `index`.
    partition = rdd.mapPartitionsWithIndex(
        lambda partition_index, partition: partition if partition_index == index else []
    ).coalesce(1)

    if partition.count() > 0:
        df = sql_context.createDataFrame(partition, schema=schema)

        for event in df.collect():
            print("Partition {0}: {1}".format(index, str(event)))
    else:
        print("Partition {0}: No rows".format(index))


To test, I load a file from S3 with 50 rows (10 in the example), each with a different state in the details.state column. To mimic that, I've parallelized the data in the example above, and the behavior is the same: I get the 50 partitions I asked for, but some are empty and some hold entries for more than one state. Here's the output for the sample set of 10:

Partition 0: Row(details=Row(state=u'AK'))
Partition 1: Row(details=Row(state=u'AL'))
Partition 1: Row(details=Row(state=u'CT'))
Partition 2: Row(details=Row(state=u'CA'))
Partition 3: No rows
Partition 4: No rows
Partition 5: Row(details=Row(state=u'AZ'))
Partition 6: Row(details=Row(state=u'CO'))
Partition 6: Row(details=Row(state=u'FL'))
Partition 6: Row(details=Row(state=u'GA'))
Partition 7: Row(details=Row(state=u'AR'))
Partition 7: Row(details=Row(state=u'DE'))
Partition 8: No rows
Partition 9: No rows


My question: is the repartitioning strategy just a suggestion to Spark or is there something fundamentally wrong with my code?

Answer

There is nothing unexpected going on here. Spark uses the hash of the partitioning key, taken as a positive modulo of the number of partitions, to distribute rows between partitions, and with 50 partitions you'll get a significant number of collisions:

from pyspark.sql.functions import expr

states = sc.parallelize([
    "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DC", "DE", "FL", "GA", 
    "HI", "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD", 
    "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ", 
    "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI", "SC", 
    "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY"
])

states_df = states.map(lambda x: (x, )).toDF(["state"])

# Count how many distinct partition ids the 51 codes map to out of 50
states_df.select(expr("pmod(hash(state), 50)")).distinct().count()
# 26
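A minimal pure-Python sketch of the same placement rule can make this concrete. It substitutes zlib.crc32 for the Murmur3 hash Spark actually uses (an assumption for illustration only, so the partition ids will not match Spark's) and shows that 51 codes cannot fill 50 partitions without sharing. For contrast it also applies a hypothetical explicit mapping, state_partitioner, that gives every code its own index. PySpark's RDD.partitionBy(numPartitions, partitionFunc) places each (key, value) pair in partition partitionFunc(key) % numPartitions, so a function like this is one way to get exactly one state per partition; that is an alternative to the approach in this answer, not part of it.

```python
import math
import zlib
from collections import defaultdict

# The 51 codes from the example above (50 states plus DC).
STATES = [
    "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DC", "DE", "FL", "GA",
    "HI", "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD",
    "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ",
    "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI", "SC",
    "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY",
]


def crc32_hash(key):
    # Stand-in hash: CRC32, not Spark's Murmur3, so ids differ from Spark's.
    return zlib.crc32(key.encode("utf-8"))


# Hypothetical explicit mapping: every known code gets its own index.
STATE_INDEX = {state: i for i, state in enumerate(STATES)}


def state_partitioner(key):
    # Raises KeyError for unknown codes instead of silently sharing a partition.
    return STATE_INDEX[key]


def place(keys, partition_func, num_partitions):
    """Group keys by partition_func(key) % num_partitions."""
    buckets = defaultdict(list)
    for key in keys:
        buckets[partition_func(key) % num_partitions].append(key)
    return buckets


hashed = place(STATES, crc32_hash, 50)
explicit = place(STATES, state_partitioner, len(STATES))

print("hash mod 50: {0} partitions used".format(len(hashed)))
for pid in sorted(hashed):
    if len(hashed[pid]) > 1:
        print("  partition {0}: {1}".format(pid, hashed[pid]))
print("explicit index: {0} partitions used".format(len(explicit)))

# Idealized uniform hash: chance that 10 keys land in 10 different
# partitions (the question's small example) is 10!/10^10.
p_all_distinct = math.factorial(10) / 10 ** 10
print("P(no collision, 10 keys, 10 partitions) = {0:.5f}".format(p_all_distinct))
```

With the RDD API this would look roughly like pair_rdd.partitionBy(len(STATES), state_partitioner), where pair_rdd is a hypothetical RDD of (state, row) pairs. The last line also explains the question's output: even under an ideal uniform hash, 10 keys spread over 10 partitions all end up alone only about 0.036% of the time, so mixed and empty partitions are the normal case.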

If you want separate files on write, it is better to use the partitionBy clause of DataFrameWriter. It creates a separate output directory for each distinct value of the partitioning column and doesn't require shuffling.
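With the question's nested schema, I believe the field first has to be copied to a top-level column, since partitionBy takes top-level column names; something like df.withColumn("state", col("details.state")).write.partitionBy("state").json(output_path), with col from pyspark.sql.functions and output_path a placeholder. Spark then writes one state=<value> subdirectory per distinct state. The sketch below does not use Spark at all; it is a hypothetical helper, write_partitioned, that only mimics that Hive-style directory layout so the structure is easy to see. The file name part-00000.json is a simplification of Spark's real part-file names.

```python
import json
import os
import tempfile
from collections import defaultdict

# A few rows shaped like the question's data; "AL" appears twice.
rows = [{"details": {"state": s}} for s in ["AL", "AK", "AZ", "AL"]]


def write_partitioned(rows, out_dir, column="state"):
    """Write rows into one <column>=<value> subdirectory per distinct value.

    The value is looked up inside the nested "details" dict, like details.state.
    This only illustrates the layout; it is not Spark's writer.
    """
    groups = defaultdict(list)
    for row in rows:
        groups[row["details"][column]].append(row)
    for value, group in groups.items():
        part_dir = os.path.join(out_dir, "{0}={1}".format(column, value))
        os.makedirs(part_dir)
        with open(os.path.join(part_dir, "part-00000.json"), "w") as f:
            for row in group:
                f.write(json.dumps(row) + "\n")
    return sorted(os.listdir(out_dir))


with tempfile.TemporaryDirectory() as out_dir:
    layout = write_partitioned(rows, out_dir)
    with open(os.path.join(out_dir, "state=AL", "part-00000.json")) as f:
        al_lines = f.read().splitlines()

print(layout)
print(len(al_lines))
```

Each state gets exactly one directory regardless of how rows were distributed across partitions beforehand, which is why this avoids the hash-collision problem for the question's goal of saving each state to its own location.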
