Scala Question

Calculate cumulative counts for `t` time periods without hardcoding

I want to calculate cumulative counts at different time steps. I have counts of events that happened during each time period `t`: now I want the cumulative number of events up to and including that period.

I can easily compute each cumulative count separately, but that is tedious. I could append the results back together with a `unionAll`, but with a large number of time periods that would be tedious too (see the sketch after the example output below).

How could I do this more cleanly?

package main.scala

import java.io.File
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions._

object Test {

    def main(args: Array[String]) {

        // Spark and SQL context (gives access to Spark and Spark SQL libraries).
        // SQLContextSingleton is a helper defined elsewhere in my project that
        // returns a cached SQLContext.
        val conf = new SparkConf().setAppName("Merger")
        val sc = new SparkContext(conf)
        val sqlContext = SQLContextSingleton.getInstance(sc)
        import sqlContext.implicits._

        // Event counts per id and time period t
        val count = Seq(("A",1,1),("A",1,2),("A",0,3),("A",0,4),("A",0,5),("A",1,6),
                        ("B",1,1),("B",0,2),("B",0,3),("B",1,4),("B",0,5),("B",1,6))
            .toDF("id","count","t")

        // One hardcoded cumulative count per time period: filter up to t, then aggregate
        val count2 = count.filter('t <= 2).groupBy('id).agg(sum("count"), max("t"))

        val count3 = count.filter('t <= 3).groupBy('id).agg(sum("count"), max("t"))

        count.show()
        count2.show()
        count3.show()
    }
}


`count`:

+---+-----+---+
| id|count| t|
+---+-----+---+
| A| 1| 1|
| A| 1| 2|
| A| 0| 3|
| A| 0| 4|
| A| 0| 5|
| A| 1| 6|
| B| 1| 1|
| B| 0| 2|
| B| 0| 3|
| B| 1| 4|
| B| 0| 5|
| B| 1| 6|
+---+-----+---+


`count2`:

+---+----------+------+
| id|sum(count)|max(t)|
+---+----------+------+
| A| 2| 2|
| B| 1| 2|
+---+----------+------+


`count3`:

+---+----------+------+
| id|sum(count)|max(t)|
+---+----------+------+
| A| 2| 3|
| B| 1| 3|
+---+----------+------+
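
For reference, the brute-force concatenation I want to avoid would look something like this (a rough sketch only: one hardcoded aggregate per period, as above, appended with `unionAll`):

// Sketch of the hardcoded approach: one filtered aggregate per period
// (count2, count3, ... as above), glued back together with unionAll.
val count1 = count.filter('t <= 1).groupBy('id).agg(sum("count"), max("t"))
// ... one such line per time period ...
val cumulative = count1.unionAll(count2).unionAll(count3) // ... and so on
cumulative.show()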

Answer

I found this in the Stack Overflow Documentation.

I have tested it with Spark 1.5.2/Scala 2.10 and Spark 2.0.0/Scala 2.11, and it worked like a charm. It did not work with Spark 1.6.2; I suspect that build was not compiled with Hive support.

package main.scala

import java.io.File
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SQLContext


object Test {

    def main(args: Array[String]) {

        val conf = new SparkConf().setAppName("Test")
        val sc = new SparkContext(conf)
        val sqlContext = SQLContextSingleton.getInstance(sc)
        import sqlContext.implicits._

        val data = Seq(("A",1,1,1),("A",3,1,3),("A",0,0,2),("A",4,0,4),("A",0,0,6),("A",2,1,5),
                       ("B",0,1,3),("B",0,0,4),("B",2,0,1),("B",2,1,2),("B",0,0,6),("B",1,1,5))
            .toDF("id","param1","param2","t")
        data.show()

        // Running totals per id, ordered by t. With an orderBy and no explicit frame,
        // the window defaults to "unbounded preceding to current row", i.e. a cumulative sum.
        data.withColumn("cumulativeSum1", sum("param1").over(Window.partitionBy("id").orderBy("t")))
            .withColumn("cumulativeSum2", sum("param2").over(Window.partitionBy("id").orderBy("t")))
            .show()
    }
}

An improvement I am still working on is applying this to several columns at once instead of repeating `withColumn`. Inputs welcome!
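
One way to do that, in case it helps anyone, is to fold over the list of column names (a sketch only, assuming every column uses the same window; the names `cols` and the `cumulative_` prefix are illustrative, not from the post above):

// Sketch: apply the same cumulative window to several columns with foldLeft.
val w = Window.partitionBy("id").orderBy("t")
val cols = Seq("param1", "param2")

val cumulative = cols.foldLeft(data) { (df, c) =>
    df.withColumn(s"cumulative_$c", sum(col(c)).over(w))
}
cumulative.show()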
