
How to sum values of a struct in a nested array in a Spark dataframe?

This is in Spark 2.1. Given this input file:


order.json

{"id":1,"price":202.30,"userid":1}
{"id":2,"price":343.99,"userid":1}
{"id":3,"price":399.99,"userid":2}



And the following dataframes:

val order = sqlContext.read.json("order.json")
val df2 = order.select(struct("*") as 'order)
val df3 = df2.groupBy("order.userId").agg( collect_list( $"order").as("array"))


df3 has the following content:

+------+---------------------------+
|userId|array                      |
+------+---------------------------+
|1     |[[1,202.3,1], [2,343.99,1]]|
|2     |[[3,399.99,2]]             |
+------+---------------------------+


and structure:

root
|-- userId: long (nullable = true)
|-- array: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- id: long (nullable = true)
| | |-- price: double (nullable = true)
| | |-- userid: long (nullable = true)


Now assuming I am given df3:


  1. I would like to compute the sum of array.price for each userId, taking advantage of having the array per userId row.

  2. I would like to add this computation as a new column in the resulting dataframe, as if I had done df3.withColumn( "sum", lit(0)) but with lit(0) replaced by my computation.



I assumed this would be straightforward, but I am stuck on both. I didn't find any way to access the array as a whole to do the computation per row (with a foldLeft, for example).

Answer

I would like to compute the sum of array.price for each userId, taking advantage of having the array

Unfortunately having an array works against you here. Neither Spark SQL nor the DataFrame DSL provides tools that can be used directly to handle this task on an array of arbitrary size without decomposing it (explode) first.

You can use a UDF:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

val totalPrice = udf((xs: Seq[Row]) => xs.map(_.getAs[Double]("price")).sum)
df3.withColumn("totalPrice", totalPrice($"array")).show
+------+--------------------+----------+ 
|userId|               array|totalPrice|
+------+--------------------+----------+
|     1|[[1,202.3,1], [2,...|    546.29|
|     2|      [[3,399.99,2]]|    399.99|
+------+--------------------+----------+
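
The array column cannot be null here (collect_list never produces a null array), but if df3 were built differently a defensive variant of the same UDF is easy to sketch; returning an Option makes the result column nullable instead of failing:

// Hypothetical null-safe variant: yields null instead of throwing when the array is null.
val totalPriceSafe = udf((xs: Seq[Row]) =>
  Option(xs).map(_.map(_.getAs[Double]("price")).sum))

df3.withColumn("totalPrice", totalPriceSafe($"array")).show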

or convert to a statically typed Dataset:

df3
  .as[(Long, Seq[(Long, Double, Long)])]
  .map{ case (id, xs) => (id, xs, xs.map(_._2).sum) }
  .toDF("userId", "array", "totalPrice").show
+------+--------------------+----------+
|userId|               array|totalPrice|
+------+--------------------+----------+
|     1|[[1,202.3,1], [2,...|    546.29|
|     2|      [[3,399.99,2]]|    399.99|
+------+--------------------+----------+
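
The tuple encoder forces positional access (_._2 for the price). If you prefer named fields, a case class that mirrors the element struct works the same way (just a sketch; the name Order is introduced here):

case class Order(id: Long, price: Double, userid: Long)

df3
  .as[(Long, Seq[Order])]
  .map { case (id, xs) => (id, xs, xs.map(_.price).sum) }
  .toDF("userId", "array", "totalPrice")
  .show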

As mentioned above, you can decompose (explode) and aggregate:

import org.apache.spark.sql.functions.{explode, first, sum}

df3
  .withColumn("price", explode($"array.price"))
  .groupBy($"userId")
  .agg(sum($"price"), df3.columns.tail.map(c => first(c).alias(c)): _*)
  .show
+------+----------+--------------------+
|userId|sum(price)|               array|
+------+----------+--------------------+
|     1|    546.29|[[1,202.3,1], [2,...|
|     2|    399.99|      [[3,399.99,2]]|
+------+----------+--------------------+

but it is expensive and doesn't use the existing structure.
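
If you want to keep the array column without the first trick, the same explode-and-aggregate idea can also be written as an aggregation on the side followed by a join (just another sketch of the same approach, and just as expensive):

import org.apache.spark.sql.functions.{explode, sum}

val sums = df3
  .select($"userId", explode($"array.price").as("price"))
  .groupBy($"userId")
  .agg(sum($"price").as("totalPrice"))

df3.join(sums, Seq("userId"))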

There is an ugly trick you could use:

import org.apache.spark.sql.functions.{coalesce, lit, max, size}

val totalPrice = (0 to df3.agg(max(size($"array"))).as[Int].first)
  .map(i => coalesce($"array.price".getItem(i), lit(0.0)))
  .foldLeft(lit(0.0))(_ + _)

df3.withColumn("totalPrice", totalPrice)
+------+--------------------+----------+
|userId|               array|totalPrice|
+------+--------------------+----------+
|     1|[[1,202.3,1], [2,...|    546.29|
|     2|      [[3,399.99,2]]|    399.99|
+------+--------------------+----------+

but it is more a curiosity than a real solution.
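
To see why: for this data, where the largest array has two elements, the fold builds a single column expression roughly equivalent to the one below, so it costs an extra job just to find the maximum size and the expression keeps growing with the largest array in the DataFrame:

// Rough expansion of the foldLeft for a maximum array size of 2 (indices 0 to 2):
lit(0.0) +
  coalesce($"array.price".getItem(0), lit(0.0)) +
  coalesce($"array.price".getItem(1), lit(0.0)) +
  coalesce($"array.price".getItem(2), lit(0.0))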