Justin Raymond - 6 months ago
Scala Question

Spark UI DAG stage disconnected

I ran the following job in the spark-shell:

val d = sc.parallelize(0 until 1000000).map(i => (i%100000, i)).persist
d.join(d.reduceByKey(_ + _)).collect

The Spark UI shows three stages. Stages 4 and 5 correspond to the computation of d, and stage 6 corresponds to the computation of the collect action. Since d is persisted, I would expect only two stages. However, stage 5 is present and not connected to any other stage.

Spark UI DAG

So I tried running the same computation without persist, and the DAG looks identical, except without the green dots indicating that the RDD has been persisted.

Spark UI DAG without persist

I would expect the output of stage 11 to be connected to the input of stage 12, but it is not.

Looking at the stage descriptions, the stages seem to indicate that d is being persisted, because stage 5 has input, but I am still confused as to why stage 5 even exists.

Spark UI stages

Spark UI stages without persist

  1. The input RDD is cached and the cached part is not recomputed.

    This can be validated with a simple test:

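    import org.apache.spark.SparkContext

    def f(sc: SparkContext) = {
      val counter = sc.longAccumulator("counter")
      val rdd = sc.parallelize(0 until 100).map(i => {
        counter.add(1L) // counts every evaluation of the map function
        (i % 10, i)
      }).persist
      rdd.join(rdd.reduceByKey(_ + _)).foreach(_ => ())
      counter.value
    }

    // 100 means each element was mapped exactly once, although rdd is used twice
    assert(f(spark.sparkContext) == 100)
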
  2. Caching doesn't remove stages from the DAG.

    If data is cached, the corresponding stages can be marked as skipped, but they are still part of the DAG. Lineage can be truncated using checkpoints, but that is not the same thing, and it doesn't remove stages from the visualization.

  3. Input stages contain more than cached computations.

    Spark stages group together operations which can be chained without performing a shuffle.

    While part of the input stage is cached, the cache doesn't cover all the operations required to prepare the shuffle files. This is why you don't see skipped tasks.

  4. The rest (detachment) is just a limitation of the graph visualization.

  5. If you repartition data first:

    import org.apache.spark.HashPartitioner
    val d = sc.parallelize(0 until 1000000)
      .map(i => (i%100000, i))
      .partitionBy(new HashPartitioner(20))
    d.join(d.reduceByKey(_ + _)).collect

    you'll get the DAG you're most likely looking for:

    Spark UI DAG after partitioning
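
Points 3 and 5 can also be checked without the UI. The following is only a rough sketch, assuming Spark is on the classpath (in spark-shell, getOrCreate should simply return the existing session). The names plain and partitioned and the application name are made up for illustration. It prints the partitioner recorded on each RDD and the lineage of both versions of the join. In the toDebugString output, shuffle dependencies appear as indented "+-" branches, so the two lineages can be compared directly.

```scala
import org.apache.spark.HashPartitioner
import org.apache.spark.sql.SparkSession

// Assumption: local mode is enough for this check; the app name is arbitrary.
val spark = SparkSession.builder.master("local[2]").appName("dag-check").getOrCreate()
val sc = spark.sparkContext

// Same shape of data as in the question, just smaller.
val plain = sc.parallelize(0 until 1000).map(i => (i % 100, i))
val partitioned = plain.partitionBy(new HashPartitioner(20)).persist()

// No partitioner is recorded on `plain`, so join and reduceByKey each have to shuffle it.
println(plain.partitioner)
// `partitioned` carries its HashPartitioner, which downstream operations can reuse.
println(partitioned.partitioner)

// Compare the lineages: indented "+-" branches mark shuffle dependencies.
println(plain.join(plain.reduceByKey(_ + _)).toDebugString)
println(partitioned.join(partitioned.reduceByKey(_ + _)).toDebugString)
```

If the second lineage shows a single shuffle (the one introduced by partitionBy), that matches the connected DAG described in point 5.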