gsamaras - 3 months ago
Python Question

How to balance my data across the partitions?

Edit: The answer helps, but I described my solution in "memoryOverhead issue in Spark".




I have an RDD with 202092 partitions, built from a dataset created by others. I can see by inspection that the data is not balanced across the partitions: for example, some of them have 0 images and others have 4k, while the mean is 432. When processing the data, I got this error:

Container killed by YARN for exceeding memory limits. 16.9 GB of 16 GB physical memory used. Consider boosting spark.yarn.executor.memoryOverhead.


even though memoryOverhead is already boosted. I suspect that occasional memory spikes push the container past its specified limit, which is why YARN kills it.
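For context on the "16 GB physical memory" figure: on YARN the container limit is the executor heap plus spark.yarn.executor.memoryOverhead, and in Spark 1.x/2.0 the overhead defaults to max(10% of executor memory, 384 MB). The sketch below just does that arithmetic; the helper name and the 12 GB + 4 GB split are hypothetical (the question does not say how the 16 GB is divided), and any rounding YARN may apply to its minimum allocation increment is not modeled.

```python
# Defaults used by Spark 1.x/2.0 on YARN when
# spark.yarn.executor.memoryOverhead is not set explicitly.
MEMORY_OVERHEAD_FACTOR = 0.10
MEMORY_OVERHEAD_MIN_MB = 384


def yarn_container_mb(executor_memory_mb, memory_overhead_mb=None):
    """Physical memory (MB) requested from YARN for one executor container.

    Hypothetical helper: the container limit is the executor heap plus the
    overhead reserved for non-heap memory. YARN's own rounding is ignored.
    """
    if memory_overhead_mb is None:
        # Default overhead: 10% of the heap, but never less than 384 MB.
        memory_overhead_mb = max(int(MEMORY_OVERHEAD_FACTOR * executor_memory_mb),
                                 MEMORY_OVERHEAD_MIN_MB)
    return executor_memory_mb + memory_overhead_mb


# Hypothetical split: a 12 GB heap plus a boosted 4 GB overhead
# reaches the 16 GB limit quoted in the error message.
print(yarn_container_mb(12 * 1024, 4 * 1024))  # 16384
# With the default overhead, the same heap gets only about 1.2 GB extra.
print(yarn_container_mb(12 * 1024))  # 13516
```

Anything running outside the JVM heap has to fit into the overhead part, so a spike there trips the limit even when the heap itself is fine.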

So what should I do to make sure that my data is (roughly) balanced across partitions?
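Before and after any fix, it helps to put a number on the imbalance. A minimal sketch, assuming the per-partition record counts have already been collected; in PySpark, rdd.mapPartitions(lambda it: [sum(1 for _ in it)]).collect() yields one count per partition without building each partition as a list the way glom() does. The partition_skew helper is hypothetical. With the figures from the question (largest partitions around 4k, mean 432), the max-to-mean ratio is roughly 9.

```python
from statistics import mean


def partition_skew(sizes):
    """Summarize how unevenly records are spread over partitions.

    `sizes` holds one record count per partition (hypothetical helper).
    """
    avg = mean(sizes)
    return {
        "partitions": len(sizes),
        "empty": sum(1 for s in sizes if s == 0),
        "min": min(sizes),
        "max": max(sizes),
        "mean": avg,
        # How much larger the biggest partition is than an average one;
        # a high ratio means a few tasks can need far more memory than the rest.
        "max_over_mean": max(sizes) / avg if avg else float("inf"),
    }


# The toy RDD below has partitions of 1000, 2000 and 3000 records.
print(partition_skew([1000, 2000, 3000])["max_over_mean"])  # 1.5
```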




My idea was that repartition() would work, since it invokes a shuffle:

dataset = dataset.repartition(202092)


but I got the very same error, despite what the programming guide says:


repartition(numPartitions)

Reshuffle the data in the RDD randomly to create either more or fewer
partitions and balance it across them. This always shuffles all data
over the network.





Yet look at my toy example:

# partition i holds the single element i and expands it to range((i + 1) * 1000)
data = sc.parallelize([0,1,2], 3).mapPartitions(lambda x: range((next(x) + 1) * 1000))
d = data.glom().collect()
len(d[0]) # 1000
len(d[1]) # 2000
len(d[2]) # 3000
repartitioned_data = data.repartition(3)
re_d = repartitioned_data.glom().collect()
len(re_d[0]) # 1854
len(re_d[1]) # 1754
len(re_d[2]) # 2392
repartitioned_data = data.repartition(6)
re_d = repartitioned_data.glom().collect()
len(re_d[0]) # 422
len(re_d[1]) # 845
len(re_d[2]) # 1643
len(re_d[3]) # 1332
len(re_d[4]) # 1547
len(re_d[5]) # 211
repartitioned_data = data.repartition(12)
re_d = repartitioned_data.glom().collect()
len(re_d[0]) # 132
len(re_d[1]) # 265
len(re_d[2]) # 530
len(re_d[3]) # 1060
len(re_d[4]) # 1025
len(re_d[5]) # 145
len(re_d[6]) # 290
len(re_d[7]) # 580
len(re_d[8]) # 1113
len(re_d[9]) # 272
len(re_d[10]) # 522
len(re_d[11]) # 66

Answer

I think the exceeded memory overhead limit is caused by DirectMemory buffers used during fetch, and that it is fixed in 2.0.0. (We had the same issue, but stopped digging deeper when we found that upgrading to 2.0.0 resolved it. Unfortunately, I don't have Spark issue numbers to back me up.)


The uneven partitions after repartition are surprising. Contrast with https://github.com/apache/spark/blob/v2.0.0/core/src/main/scala/org/apache/spark/rdd/RDD.scala#L443. Spark does not hash the records themselves in repartition: each input partition starts at a random position and then assigns consecutive keys, so the records are dealt out round-robin rather than by a hash that could be biased.
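To see why round-robin keys should give balanced output, here is a pure-Python sketch of the keying step at the linked line: each input partition picks a start position from a generator seeded with its index, then emits consecutive integer keys, and the hash partitioner sends key k to k mod numPartitions. Python's random.Random stands in for java.util.Random, so the exact start positions differ from Spark's; only the round-robin behavior is the point. The sketch models individual records only, not how PySpark ships Python objects to the JVM, and simulate_repartition is a hypothetical name.

```python
import random


def simulate_repartition(partition_sizes, num_partitions):
    """Count how many records land in each output partition.

    Mimics the shuffle keying in RDD.coalesce(shuffle=True): a seeded random
    start per input partition, then consecutive keys taken mod num_partitions.
    """
    counts = [0] * num_partitions
    for index, size in enumerate(partition_sizes):
        # Seeded by the input partition index, like Spark's Random(index).
        position = random.Random(index).randrange(num_partitions)
        for _ in range(size):
            position += 1
            counts[position % num_partitions] += 1
    return counts


for n in (3, 6, 12):
    print(n, simulate_repartition([1000, 2000, 3000], n))
```

Each input partition of n records contributes either floor(n/k) or ceil(n/k) records to every one of the k outputs, so with three inputs the output sizes can differ by at most 3, which is the shape of the spark-shell results in this answer.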

I tried your example and got exactly the same results with Spark 1.6.2 and Spark 2.0.0, but not from the Scala spark-shell:

scala> val data = sc.parallelize(1 to 3, 3).mapPartitions { it => (1 to it.next * 1000).iterator }
data: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[6] at mapPartitions at <console>:24

scala> data.mapPartitions { it => Iterator(it.toSeq.size) }.collect.toSeq
res1: Seq[Int] = WrappedArray(1000, 2000, 3000)

scala> data.repartition(3).mapPartitions { it => Iterator(it.toSeq.size) }.collect.toSeq
res2: Seq[Int] = WrappedArray(1999, 2001, 2000)

scala> data.repartition(6).mapPartitions { it => Iterator(it.toSeq.size) }.collect.toSeq
res3: Seq[Int] = WrappedArray(999, 1000, 1000, 1000, 1001, 1000)

scala> data.repartition(12).mapPartitions { it => Iterator(it.toSeq.size) }.collect.toSeq
res4: Seq[Int] = WrappedArray(500, 501, 501, 501, 501, 500, 499, 499, 499, 499, 500, 500)

Such beautiful partitions!


(Sorry this is not a full answer. I just wanted to share my findings so far.)