Knight71 - 2 months ago
Scala Question

How to compute cumulative sum using Spark

I have an RDD of (String, Int) pairs that is sorted by key:

val data = Array(("c1", 6), ("c2", 3), ("c3", 4))
val rdd = sc.parallelize(data).sortByKey()


I want the value for the first key to be 0, and the value for each subsequent key to be the sum of the values of all the keys before it.

For example: c1 = 0, c2 = c1's value, c3 = c1's value + c2's value, c4 = sum of the values of c1 through c3, and so on.
Expected output:

(c1,0), (c2,6), (c3,9)...
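The requested output is an exclusive prefix sum: each key gets the sum of the values strictly before it. On a local collection this is a single scanLeft. The following is a plain-Scala sketch of the desired semantics only (no Spark involved); the object name ExclusivePrefixSum is hypothetical.

```scala
// Plain-Scala sketch (no Spark): exclusive prefix sum over (key, value) pairs.
object ExclusivePrefixSum {
  def exclusive(data: Seq[(String, Int)]): Seq[(String, Int)] = {
    val (keys, values) = data.unzip
    // scanLeft(0) yields 0, v1, v1+v2, ...; zip drops the final grand total,
    // so each key is paired with the sum of the values before it.
    keys.zip(values.scanLeft(0)(_ + _))
  }
}
```

For example, ExclusivePrefixSum.exclusive(Seq(("c1", 6), ("c2", 3), ("c3", 4))) gives (c1,0), (c2,6), (c3,9). The difficulty in Spark is doing the same thing when the data is split across partitions.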


Is it possible to achieve this?
I tried it with map, but the running sum is not preserved inside the map:

var sum = 0
val t = keycount.map { case (k, v) => val temp = sum; sum += v; (k, temp) }
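The Spark programming guide (section "Understanding closures") explains why this fails: the closure passed to map is serialized and shipped to the executors, so each task works on its own copy of sum. The driver's variable is never updated, and the running total restarts in every partition. The sketch below only simulates that behavior in plain Scala (it does not use Spark); ClosureCopySimulation is a hypothetical name, and each inner Seq stands in for one partition.

```scala
// Hypothetical simulation (plain Scala, no Spark) of the attempt above:
// each inner Seq plays the role of one partition, and each "task" starts
// from its own fresh copy of `sum`, mimicking a serialized closure.
object ClosureCopySimulation {
  def runningSumPerTask(partitions: Seq[Seq[(String, Int)]]): Seq[(String, Int)] =
    partitions.flatMap { partition =>
      var sum = 0 // per-task copy; nothing is shared between partitions
      partition.map { case (k, v) =>
        val temp = sum
        sum += v
        (k, temp)
      }
    }
}
```

With the keys split as [c1] and [c2, c3], the simulation returns (c1,0), (c2,0), (c3,3) instead of (c1,0), (c2,6), (c3,9). With a single partition it happens to give the right answer, which is one way this kind of bug can go unnoticed on small local tests.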

Answer
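  1. Compute partial results for each partition (the local exclusive sums and the partition total):

    // Zipping with `sums` (which starts at 0) gives each key the sum of the values
    // before it, as the question asks; zipping with `sums.tail` would give inclusive sums.
    val partials = rdd.mapPartitions(iter => {
      val (keys, values) = iter.toVector.unzip
      val sums = values.scanLeft(0)(_ + _)
      Iterator((keys.zip(sums), sums.last))
    }).cache() // reused by steps 2 and 4, so avoid recomputing it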
    
  2. Collect the per-partition totals (collect returns them in partition order):

    val partialSums = partials.values.collect()
    
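  3. Compute each partition's offset (the cumulative sum of all preceding partition totals) and broadcast it:

    // Offset of partition i = sum of the totals of partitions 0 until i.
    val sumMap = sc.broadcast(
      (0 until rdd.partitions.size)
        .zip(partialSums.scanLeft(0)(_ + _))
        .toMap
    )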
    
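  4. Add each partition's offset to its local sums:

    val result = partials.keys.mapPartitionsWithIndex((i, iter) => {
      val offset = sumMap.value(i)
      // Each partition of `partials.keys` holds exactly one Vector of (key, localSum) pairs.
      if (iter.isEmpty) Iterator()
      else iter.next().map { case (k, v) => (k, v + offset) }.iterator
    })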
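The steps above can be checked without a cluster. The following plain-Scala sketch is a simulation, not Spark code: each inner Seq stands in for one partition of the sorted RDD, a local Seq stands in for the collected totals, and an indexed Vector replaces the broadcast Map. The name PartitionedPrefixSum is hypothetical.

```scala
// Plain-Scala simulation (no Spark) of the four-step partitioned prefix sum.
object PartitionedPrefixSum {
  def cumulative(partitions: Seq[Seq[(String, Int)]]): Seq[(String, Int)] = {
    // Step 1: per partition, local exclusive sums plus the partition total.
    val partials: Seq[(Seq[(String, Int)], Int)] = partitions.map { part =>
      val (keys, values) = part.unzip
      val sums = values.scanLeft(0)(_ + _)
      (keys.zip(sums), sums.last)
    }
    // Step 2: "collect" the partition totals, in partition order.
    val partialSums = partials.map(_._2)
    // Step 3: offset of partition i = sum of the totals of partitions 0 until i.
    val offsets = partialSums.scanLeft(0)(_ + _).toVector
    // Step 4: shift each partition's local sums by that partition's offset.
    partials.zipWithIndex.flatMap { case ((local, _), i) =>
      local.map { case (k, v) => (k, v + offsets(i)) }
    }
  }
}
```

The result matches a single scanLeft over the concatenated data however the keys are split, including when some partitions are empty. In Spark the same reasoning relies on collect returning results in partition order, so that partialSums(i) really is the total of partition i.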