Fakhri Zouari - 1 year ago
Scala Question

Run a Cumulative/Iterative Custom Method on a Column in Spark Scala

Hi, I am new to Spark/Scala. I have been trying (and so far failing) to create a column in a Spark DataFrame based on a particular recursive formula.

Here it is in pseudocode:

someDf.col2[0] = 0

for i > 0:
    someDf.col2[i] = x * someDf.col1[i-1] + (1 - x) * someDf.col2[i-1]
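
To make the recursion concrete, here is a minimal plain-Scala sketch of the same formula on an ordinary sequence (not Spark; the weight and the column values are just example numbers):

val x = 0.5                               // example weight
val col1 = Seq(0.0, 0.0, 1.0, 1.0)        // example col1 values
// scanLeft carries the running col2 value forward; the 0.0 seed gives col2[0] = 0
val col2 = col1.scanLeft(0.0)((prev, c1) => x * c1 + (1 - x) * prev).init
// col2(0) == 0, and col2(i) == x * col1(i-1) + (1 - x) * col2(i-1)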


To dive into more details, here is my starting point: this DataFrame is the result of aggregations at the level of both dates and individual id's.

All further calculations have to happen with respect to that particular id, and have to take into consideration what happened in the previous week.

To illustrate this, I have simplified the values to zeros and ones, removed the multipliers x and 1-x, and initialized col2 to zero.

var someDf = Seq(("2016-01-10 00:00:00.0","385608",0,0),
("2016-01-17 00:00:00.0","385608",0,0),
("2016-01-24 00:00:00.0","385608",1,0),
("2016-01-31 00:00:00.0","385608",1,0),
("2016-02-07 00:00:00.0","385608",1,0),
("2016-02-14 00:00:00.0","385608",1,0),
("2016-01-17 00:00:00.0","105010",0,0),
("2016-01-24 00:00:00.0","105010",1,0),
("2016-01-31 00:00:00.0","105010",0,0),
("2016-02-07 00:00:00.0","105010",1,0)
).toDF("dates", "id", "col1","col2" )

someDf.show()

+--------------------+------+----+----+
|               dates|    id|col1|col2|
+--------------------+------+----+----+
|2016-01-10 00:00:...|385608|   0|   0|
|2016-01-17 00:00:...|385608|   0|   0|
|2016-01-24 00:00:...|385608|   1|   0|
|2016-01-31 00:00:...|385608|   1|   0|
|2016-02-07 00:00:...|385608|   1|   0|
|2016-02-14 00:00:...|385608|   1|   0|
|2016-01-17 00:00:...|105010|   0|   0|
|2016-01-24 00:00:...|105010|   1|   0|
|2016-01-31 00:00:...|105010|   0|   0|
|2016-02-07 00:00:...|105010|   1|   0|
+--------------------+------+----+----+


What I have tried so far vs. what is desired:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val date_id_window = Window.partitionBy("id").orderBy(asc("dates"))

someDf.withColumn("col2", lag($"col1", 1).over(date_id_window) +
                          lag($"col2", 1).over(date_id_window)).show()


+--------------------+------+----+----+ / +--------------------+
|               dates|    id|col1|col2| / | what_col2_should_be|
+--------------------+------+----+----+ / +--------------------+
|2016-01-17 00:00:...|105010|   0|null| / |                   0|
|2016-01-24 00:00:...|105010|   1|   0| / |                   0|
|2016-01-31 00:00:...|105010|   0|   1| / |                   1|
|2016-02-07 00:00:...|105010|   1|   0| / |                   1|
|2016-01-10 00:00:...|385608|   0|null| / |                   0|
|2016-01-17 00:00:...|385608|   0|   0| / |                   0|
|2016-01-24 00:00:...|385608|   1|   0| / |                   0|
|2016-01-31 00:00:...|385608|   1|   1| / |                   1|
|2016-02-07 00:00:...|385608|   1|   1| / |                   2|
|2016-02-14 00:00:...|385608|   1|   1| / |                   3|
+--------------------+------+----+----+ / +--------------------+


Is there a way to do this with a Spark DataFrame? I have seen multiple cumulative-type computations, but never one that feeds back into the same column. I believe the problem is that the newly computed value for row i-1 is not considered; instead the old i-1 value is used, which is always 0.
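
(Side note: in this simplified example, where x and 1-x are removed, the recursion collapses to col2[i] = col1[i-1] + col2[i-1], i.e. a running sum of the previous rows' col1. That special case can be expressed with a window aggregate, as the sketch below shows, but it does not cover the general weighted formula I actually need:)

// covers only the simplified example, not the weighted recursion
val runningSumWindow = Window.partitionBy("id").orderBy(asc("dates"))
  .rowsBetween(Window.unboundedPreceding, -1)

someDf.withColumn("col2",
  coalesce(sum($"col1").over(runningSumWindow), lit(0))).show()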

Any help would be appreciated.

Answer

A Dataset should work just fine:

val x = 0.1

case class Record(dates: String, id: String, col1: Int)

someDf.drop("col2").as[Record]
  .groupByKey(_.id)
  .flatMapGroups((_, records) => {
    // collect each id group locally and sort it chronologically
    val sorted = records.toSeq.sortBy(_.dates)
    // scanLeft threads the running col2 value through the sorted group;
    // the (null, 0.0) seed starts the accumulator at 0 and is dropped by .tail
    sorted.scanLeft((null: Record, 0.0)) {
      case ((_, col2), record) => (record, x * record.col1 + (1 - x) * col2)
    }.tail
  })
  .select($"_1.*", $"_2".alias("col2"))
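
Note that flatMapGroups collects each id group into a local Seq and sorts it, so this assumes every group fits comfortably in executor memory. To inspect the output in a stable order afterwards, something along these lines should work (result is just a placeholder name for the Dataset returned above):

result.orderBy($"id", $"dates").show()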