clay - 5 months ago
Scala Question

Spark merge/combine arrays in groupBy/aggregate

The following Spark code correctly demonstrates what I want to do and generates the correct output with a tiny demo data set.

When I run the same kind of code on a large volume of production data, I run into runtime problems: the Spark job runs on my cluster for about 12 hours and then fails.

Just glancing at the code below, it seems inefficient to explode every row only to merge it back down. In the test data set, the fourth row has three values in array_value_1 and three in array_value_2, so it explodes into 3*3 = 9 rows.

So, in a larger data set, would a row with five such array columns and ten values in each column explode into 10^5 rows?
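The arithmetic behind that question can be checked in plain Scala. This sketch assumes that chaining explode over several array columns yields, for each input row, the product of that row's array lengths (which matches the 19 rows counted for the demo data below); the helper name explodedRows is hypothetical. A zero in the product also mirrors the fact that explode drops rows whose array is empty.

```scala
// Counts the rows produced by chained explode() calls: each input row
// yields the product of its array lengths, and the total is the sum over rows.
object ExplodeBlowup {
  def explodedRows(arrayLengthsPerRow: Seq[Seq[Int]]): Long =
    arrayLengthsPerRow.map(_.map(_.toLong).product).sum

  def main(args: Array[String]): Unit = {
    // (array_value_1, array_value_2) lengths for the four demo rows
    val demo = Seq(Seq(2, 2), Seq(2, 2), Seq(2, 1), Seq(3, 3))
    println(explodedRows(demo)) // 19

    // One hypothetical row with five array columns of ten values each
    println(explodedRows(Seq(Seq.fill(5)(10)))) // 100000
  }
}
```

So the answer to the question is yes: the blow-up is multiplicative per row, which is why a few wide rows can dominate the job.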

Looking at the built-in Spark functions, none of them does what I want out of the box. I could supply a user-defined function (UDF). Are there any speed drawbacks to that?

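import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, collect_set, explode}
import org.apache.spark.sql.types._
import scala.collection.JavaConverters._

val sparkSession = SparkSession.builder.
  appName("merge list test").
  getOrCreate()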

val schema = StructType(
  StructField("category", IntegerType) ::
  StructField("array_value_1", ArrayType(StringType)) ::
  StructField("array_value_2", ArrayType(StringType)) ::
  Nil)

val rows = List(
  Row(1, List("a", "b"), List("u", "v")),
  Row(1, List("b", "c"), List("v", "w")),
  Row(2, List("c", "d"), List("w")),
  Row(2, List("c", "d", "e"), List("x", "y", "z"))
)

val df = sparkSession.createDataFrame(rows.asJava, schema)

val dfExploded = df.
withColumn("scalar_1", explode(col("array_value_1"))).
withColumn("scalar_2", explode(col("array_value_2")))

// This will output 19: 2*2 + 2*2 + 2*1 + 3*3 = 19
println(s"dfExploded.count()=${dfExploded.count()}")

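// Merge the exploded scalars back into one de-duplicated array per category
val dfOutput = dfExploded.groupBy("category").agg(
  collect_set("scalar_1"),
  collect_set("scalar_2"))

dfOutput.show()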

Answer

Exploding may be inefficient, but fundamentally the operation you are trying to implement is simply expensive. In effect it is just another groupByKey, and there is not much you can do to make it cheaper. Since you are using Spark 2.0 or later, you can collect_list the array columns directly and flatten the result:

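import org.apache.spark.sql.functions.{collect_list, udf}

val flatten = udf((xs: Seq[Seq[String]]) => xs.flatten.distinct) // concatenate the arrays, then de-duplicate

df.groupBy("category").agg(
  flatten(collect_list("array_value_1")),
  flatten(collect_list("array_value_2")))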

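On Spark 2.4 or later, the same merge can likely be expressed with the built-in functions flatten and array_distinct, which keeps the work inside Spark's own expressions rather than a Scala UDF (relevant to the UDF-overhead question). This is a minimal sketch under that version assumption; the object name MergeWithBuiltins and the local master are illustrative only.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{array_distinct, col, collect_list, flatten}

object MergeWithBuiltins {
  // Groups by category and merges each array column into one de-duplicated
  // array using only built-in functions (requires Spark 2.4+), no UDF.
  def merge(df: DataFrame): DataFrame =
    df.groupBy("category").agg(
      array_distinct(flatten(collect_list(col("array_value_1")))).alias("array_value_1"),
      array_distinct(flatten(collect_list(col("array_value_2")))).alias("array_value_2")
    )

  def main(args: Array[String]): Unit = {
    // Local session for trying the sketch; on a cluster the master comes from spark-submit
    val spark = SparkSession.builder.master("local[*]").appName("merge builtins").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (1, Seq("a", "b"), Seq("u", "v")),
      (1, Seq("b", "c"), Seq("v", "w")),
      (2, Seq("c", "d"), Seq("w")),
      (2, Seq("c", "d", "e"), Seq("x", "y", "z"))
    ).toDF("category", "array_value_1", "array_value_2")

    merge(df).show(false)
    spark.stop()
  }
}
```

As with collect_list in general, the order of elements inside the merged arrays is not guaranteed, so compare them as sets.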
It is also possible to use a custom Aggregator, but I doubt any of these will make a huge difference.
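For reference, a custom Aggregator might look roughly like the sketch below. It assumes a typed Dataset of (category, array_value_1, array_value_2) tuples; MergeArrays and MergeArraysExample are hypothetical names, the Kryo buffer encoder is one choice among several, and ExpressionEncoder lives in Spark's internal catalyst package, so it may change between versions. Sorting in finish only makes the output deterministic.

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

// Input: one row's pair of arrays. Buffer: a pair of sets, so duplicates are
// dropped as rows arrive. Output: a pair of de-duplicated, sorted arrays.
object MergeArrays
    extends Aggregator[(Seq[String], Seq[String]), (Set[String], Set[String]), (Seq[String], Seq[String])] {
  def zero: (Set[String], Set[String]) = (Set.empty, Set.empty)

  def reduce(acc: (Set[String], Set[String]), row: (Seq[String], Seq[String])): (Set[String], Set[String]) =
    (acc._1 ++ row._1, acc._2 ++ row._2)

  def merge(a: (Set[String], Set[String]), b: (Set[String], Set[String])): (Set[String], Set[String]) =
    (a._1 ++ b._1, a._2 ++ b._2)

  def finish(acc: (Set[String], Set[String])): (Seq[String], Seq[String]) =
    (acc._1.toList.sorted, acc._2.toList.sorted)

  // Assumption: Kryo for the buffer, catalyst's ExpressionEncoder for the output
  def bufferEncoder: Encoder[(Set[String], Set[String])] = Encoders.kryo[(Set[String], Set[String])]
  def outputEncoder: Encoder[(Seq[String], Seq[String])] = ExpressionEncoder[(Seq[String], Seq[String])]()
}

object MergeArraysExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("merge aggregator").getOrCreate()
    import spark.implicits._

    val ds = Seq(
      (1, Seq("a", "b"), Seq("u", "v")),
      (1, Seq("b", "c"), Seq("v", "w")),
      (2, Seq("c", "d"), Seq("w")),
      (2, Seq("c", "d", "e"), Seq("x", "y", "z"))
    ).toDS()

    // Key by category, keep only the two arrays, then apply the aggregator
    val merged = ds
      .groupByKey(_._1)
      .mapValues(r => (r._2, r._3))
      .agg(MergeArrays.toColumn)

    merged.show(false)
    spark.stop()
  }
}
```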

If the sets are relatively large and you expect a significant number of duplicates, you could try aggregateByKey with mutable sets:

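import scala.collection.mutable.{Set => MSet}
import org.apache.spark.sql.functions.struct
import sparkSession.implicits._

val rdd = df
  .select($"category", struct($"array_value_1", $"array_value_2"))
  .as[(Int, (Seq[String], Seq[String]))]
  .rdd

val agg = rdd
  .aggregateByKey((MSet[String](), MSet[String]()))(
    // seqOp: add one row's arrays to the accumulator sets in place
    { case ((accX, accY), (xs, ys)) => (accX ++= xs, accY ++= ys) },
    // combOp: merge accumulators built on different partitions
    { case ((accX1, accY1), (accX2, accY2)) => (accX1 ++= accX2, accY1 ++= accY2) }
  )
  .mapValues { case (xs, ys) => (xs.toArray, ys.toArray) }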
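aggregateByKey applies seqOp within each partition and then combOp across partitions. The plain-Scala sketch below mimics that two-phase flow on a hypothetical two-partition split of the demo rows, using the same mutable-set functions, so the merge logic can be checked without a cluster. It is a simulation of the semantics, not Spark's implementation; like Spark, it starts every key in every partition from a fresh zero value.

```scala
import scala.collection.mutable.{Set => MSet}

object AggregateByKeySim {
  type Acc = (MSet[String], MSet[String])

  // seqOp: fold one row's arrays into the accumulator, mutating it in place
  val seqOp: (Acc, (Seq[String], Seq[String])) => Acc = {
    case ((accX, accY), (xs, ys)) => (accX ++= xs, accY ++= ys)
  }

  // combOp: merge two partial accumulators for the same key
  val combOp: (Acc, Acc) => Acc = {
    case ((accX1, accY1), (accX2, accY2)) => (accX1 ++= accX2, accY1 ++= accY2)
  }

  def aggregateByKey(
      partitions: Seq[Seq[(Int, (Seq[String], Seq[String]))]]
  ): Map[Int, (Set[String], Set[String])] = {
    // Phase 1: inside each partition, fold the rows of each key from a fresh zero
    val partials: Seq[Map[Int, Acc]] = partitions.map { part =>
      part.groupBy(_._1).map { case (k, rows) =>
        k -> rows.map(_._2).foldLeft[Acc]((MSet.empty[String], MSet.empty[String]))(seqOp)
      }
    }
    // Phase 2: combine the per-partition accumulators of each key
    partials.flatten.groupBy(_._1).map { case (k, accs) =>
      val (xs, ys) = accs.map(_._2).reduce(combOp)
      k -> (xs.toSet, ys.toSet)
    }
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical split of the demo rows into two partitions
    val p1 = Seq((1, (Seq("a", "b"), Seq("u", "v"))), (2, (Seq("c", "d"), Seq("w"))))
    val p2 = Seq((1, (Seq("b", "c"), Seq("v", "w"))), (2, (Seq("c", "d", "e"), Seq("x", "y", "z"))))
    aggregateByKey(Seq(p1, p2)).toSeq.sortBy(_._1).foreach(println)
  }
}
```

Mutating the accumulators is safe here for the same reason it is in Spark: each key in each partition gets its own zero, so no set is shared between keys.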