maxymoo - 11 months ago
Scala Question

How to get Spark DataFrames from grouped data

I have a dataframe, and I want to group by a column and turn the groups back into dataframes with the same schema. The reason is that I want to map a function with signature

DataFrame -> String
across the groups. Here is what I'm trying:

val df = sc.parallelize(Seq((1,2,3),(1,2,4),(2,3,4))).toDF
val schema = df.schema
val groups = df.rdd.groupBy(x => x(0))
.mapValues(g => sqlContext.createDataFrame(sc.makeRDD(g.toList), schema))

Here's what I'm hoping for:

scala> groups(0)._2.collect
Array[org.apache.spark.sql.Row] = Array([1,2,3], [1,2,4])

but it's not working (the tasks are failing) ... I guess you can't map a function that refers to the Spark context, but I'm not sure how else to achieve this?

Answer Source

"I guess you can't map a function that refers to the Spark context"

Correct - you can't use any of Spark's context objects (or RDDs, or DataFrames) inside a function passed to any of Spark's higher-order functions. Doing so would require serializing these objects and sending them to the executors, and they are intentionally not serializable, because it wouldn't make sense: each executor would then have to behave like another driver application.
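To illustrate the rule, here is a minimal sketch of a closure that is fine to ship to the executors, because it touches only the grouped rows and no driver-side object. It assumes Spark 2.x or later (SparkSession) and that the per-group function can be rewritten to take an Iterable[Row] instead of a DataFrame, which may not fit your case. The names GroupRowsOnExecutors and describeRows are hypothetical, made up for this example.

```scala
import org.apache.spark.sql.{Row, SparkSession}

object GroupRowsOnExecutors {
  // Hypothetical stand-in for the per-group function, working on plain rows.
  // Rows are rendered as strings and sorted so the result does not depend on row order.
  def describeRows(rows: Iterable[Row]): String =
    rows.map(_.mkString("[", ",", "]")).toSeq.sorted.mkString(" ")

  def main(args: Array[String]): Unit = {
    // Assumption: a local session, just for the demo.
    val spark = SparkSession.builder().master("local[*]").appName("rows-per-group").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 2, 3), (1, 2, 4), (2, 3, 4)).toDF()

    // The functions passed to groupBy and mapValues capture no SparkContext,
    // RDD or DataFrame, so they can be serialized and run on the executors.
    val summaries = df.rdd
      .groupBy(row => row.getInt(0))
      .mapValues(rows => describeRows(rows))
      .collectAsMap()

    summaries.foreach { case (key, summary) => println(s"$key -> $summary") }
    spark.stop()
  }
}
```

Note that groupBy pulls all rows of a group onto a single executor, so this only makes sense when each group fits comfortably in memory.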

To get a DataFrame containing only one "group", I'd recommend using filter instead of groupBy: you can first collect all the group keys, then map each one to a filtered DataFrame:

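import org.apache.spark.sql.{DataFrame, Row}

val df = sc.parallelize(Seq((1,2,3),(1,2,4),(2,3,4))).toDF

df.cache() // EDIT: this might speed this up significantly, as DF will be reused instead of recalculated for each key

// collect the distinct values of the first column (the group keys) to the driver
val groupKeys: Array[Int] = df.rdd.map { case Row(i: Int, _, _) => i }.distinct().collect()
// groupKeys is a local array, so this map runs on the driver and may refer to df
val dfPerKey: Array[DataFrame] = groupKeys.map(k => df.filter($"_1" === k))

dfPerKey.foreach(_.show())
// prints: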
//    +---+---+---+
//    | _1| _2| _3|
//    +---+---+---+
//    |  1|  2|  3|
//    |  1|  2|  4|
//    +---+---+---+
//    +---+---+---+
//    | _1| _2| _3|
//    +---+---+---+
//    |  2|  3|  4|
//    +---+---+---+
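
Since the goal in the question was to map a DataFrame -> String function across the groups, the same idea can be wrapped up as in the sketch below. It assumes Spark 2.x or later (SparkSession instead of sc / sqlContext). PerGroupSummary, summarize and describe are hypothetical names, and describe is only a placeholder for your own function.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object PerGroupSummary {
  // Hypothetical placeholder for the asker's DataFrame => String function.
  // Rows are rendered as strings and sorted so the result does not depend on row order.
  def describe(group: DataFrame): String =
    group.collect().map(_.mkString("[", ",", "]")).sorted.mkString(" ")

  // Applies `describe` to one filtered DataFrame per distinct value of `keyColumn`.
  def summarize(df: DataFrame, keyColumn: String): Map[Any, String] = {
    // The keys are collected to the driver ...
    val keys = df.select(keyColumn).distinct().collect().map(_.get(0))
    // ... so this is an ordinary Scala map running on the driver,
    // where referring to `df` is allowed. It starts one Spark job per key.
    keys.map(k => k -> describe(df.filter(df(keyColumn) === k))).toMap
  }

  def main(args: Array[String]): Unit = {
    // Assumption: a local session, just for the demo.
    val spark = SparkSession.builder().master("local[*]").appName("per-group").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 2, 3), (1, 2, 4), (2, 3, 4)).toDF()
    df.cache() // df is scanned once per key, so caching should help

    summarize(df, "_1").foreach { case (key, summary) => println(s"$key -> $summary") }
    spark.stop()
  }
}
```

Because every key triggers its own filter over df, this is probably only practical when the number of distinct keys is small.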