
Why does collect() on a DataFrame with 1 row use 2000 executors?

This is the simplest DataFrame I could think of. I'm using PySpark 1.6.1.

# one row of data
rows = [ (1, 2) ]
cols = [ "a", "b" ]
df = sqlContext.createDataFrame(rows, cols)

So the DataFrame fits completely in memory, has no references to any files, and looks quite trivial to me.

Yet when I collect the data, Spark launches 2000 tasks:

[Stage 2:===================================================>(1985 + 15) / 2000]

and then the expected output:

[Row(a=1, b=2)]

Why is this happening? Shouldn't the DataFrame be completely in memory on the driver?


So I looked into the code a bit to figure out what was going on. It turns out that sqlContext.createDataFrame makes no attempt to choose reasonable parameter values based on the data.

Why 2000 tasks?

Spark uses 2000 tasks because my DataFrame has 2000 partitions, and collect() runs one task per partition. (Even though having more partitions than rows seems like clear nonsense.)

This can be seen by:

>>> df.rdd.getNumPartitions()
2000

Why did the DataFrame have 2000 partitions?

This happens because sqlContext.createDataFrame winds up using the default number of partitions (2000 in my case), irrespective of how the data is organized or how many rows it has.
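
This is easy to confirm in the same session (a quick check; sc here is the SparkContext behind sqlContext, and the 2000 evidently comes from this cluster's spark.default.parallelism setting):

>>> sc.defaultParallelism
2000
>>> sc.parallelize(rows).getNumPartitions()
2000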

The code trail is as follows.

In the sql/ module, the sqlContext.createDataFrame function calls (in this example):

rdd, schema = self._createFromLocal(data, schema)

which in turn calls:

return self._sc.parallelize(data), schema

And SparkContext.parallelize (invoked above as self._sc.parallelize) chooses the number of slices with:

numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism

No check is done on the number of rows, and it is not possible to specify the number of slices from sqlContext.createDataFrame.
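
A workaround at creation time (a sketch reusing rows and cols from above; the names rows_rdd and df1 are just illustrative) is to parallelize the data yourself with an explicit slice count and hand the resulting RDD to createDataFrame:

>>> rows_rdd = sc.parallelize(rows, numSlices=1)  # force a single partition
>>> df1 = sqlContext.createDataFrame(rows_rdd, cols)
>>> df1.rdd.getNumPartitions()
1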

How can I change how many partitions the DataFrame has?

Using DataFrame.coalesce:

>>> smdf = df.coalesce(1)
>>> smdf.rdd.getNumPartitions()
1
>>> smdf.explain()
== Physical Plan ==
Coalesce 1
+- Scan ExistingRDD[a#0L,b#1L]
>>> smdf.collect()
[Row(a=1, b=2)]
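
One caveat: coalesce only merges existing partitions, so it can lower the count but not raise it. To increase the number of partitions again (at the cost of a full shuffle), use DataFrame.repartition; the target of 4 below is arbitrary:

>>> smdf.repartition(4).rdd.getNumPartitions()
4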