ohruunuruus - 3 months ago
Scala Question

Defining a UDF that accepts an Array of objects in a Spark DataFrame?

When working with Spark's DataFrames, User Defined Functions (UDFs) are required for mapping data in columns. UDFs require that argument types are explicitly specified. In my case, I need to manipulate a column that is made up of arrays of objects, and I do not know what type to use. Here's an example:

import sqlContext.implicits._

// Start with some data. Each row (here, there's only one row)
// is a topic and a bunch of subjects
val data = sqlContext.read.json(sc.parallelize(Seq(
  // stripMargin removes the leading '|' characters so the string is valid JSON
  """
    |{
    |  "topic" : "pets",
    |  "subjects" : [
    |    {"type" : "cat", "score" : 10},
    |    {"type" : "dog", "score" : 1}
    |  ]
    |}
  """.stripMargin)))


It's relatively straightforward to use the built-in org.apache.spark.sql.functions to perform basic operations on the data in the columns:

import org.apache.spark.sql.functions.size
data.select($"topic", size($"subjects")).show

+-----+--------------+
|topic|size(subjects)|
+-----+--------------+
| pets|             2|
+-----+--------------+


and it's generally easy to write custom UDFs to perform arbitrary operations:

import org.apache.spark.sql.functions.udf
val enhance = udf { topic : String => topic.toUpperCase() }
data.select(enhance($"topic"), size($"subjects")).show

+----------+--------------+
|UDF(topic)|size(subjects)|
+----------+--------------+
|      PETS|             2|
+----------+--------------+


But what if I want to use a UDF to manipulate the array of objects in the "subjects" column? What type do I use for the argument in the UDF? For example, suppose I want to reimplement the size function instead of using the one provided by Spark:

val my_size = udf { subjects: Array[Something] => subjects.size }
data.select($"topic", my_size($"subjects")).show


Clearly Array[Something] does not work... what type should I use!? Should I ditch Array[] altogether? Poking around tells me scala.collection.mutable.WrappedArray may have something to do with it, but there's still another type I need to provide.

Answer

What you're looking for is Seq[org.apache.spark.sql.Row]:

import org.apache.spark.sql.Row

val my_size = udf { subjects: Seq[Row] => subjects.size }
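As a sketch of how this looks end to end, the example below applies the size UDF together with a second UDF that reads individual struct fields from each Row by name with getAs. It assumes Spark 2.2 or later running with a local SparkSession (needed for spark.read.json on a Dataset[String]); the object name SeqRowUdf, the UDF topType and the helper run are hypothetical names chosen for illustration.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.udf

object SeqRowUdf {
  // An array<struct<...>> column reaches a Scala UDF as a Seq whose elements are Rows.
  val mySize = udf { subjects: Seq[Row] => subjects.size }

  // Struct fields are read by name with getAs. Neither the name nor the type
  // parameter is checked at compile time; a mistake only surfaces when the UDF runs.
  // Assumes a non-null, non-empty array (maxBy throws on an empty Seq).
  val topType = udf { subjects: Seq[Row] =>
    subjects.maxBy(_.getAs[Long]("score")).getAs[String]("type")
  }

  def run(spark: SparkSession): Array[Row] = {
    import spark.implicits._
    // The question's record on a single line; JSON inference types "score" as LongType.
    val json = """{"topic": "pets", "subjects": [{"type": "cat", "score": 10}, {"type": "dog", "score": 1}]}"""
    val data = spark.read.json(Seq(json).toDS())
    data.select($"topic", mySize($"subjects"), topType($"subjects")).collect()
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("seq-row-udf").getOrCreate()
    try run(spark).foreach(println)
    finally spark.stop()
  }
}
```

Note that getAs[Long] is used rather than getAs[Int] because JSON schema inference maps integer literals to LongType; asking for the wrong numeric type is exactly the kind of error the compiler cannot catch here.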

Explanation:

  • As you already noticed, ArrayType values are currently passed to Scala UDFs as WrappedArray, so Array won't work. Declaring the parameter as Seq is the safer choice because it doesn't depend on that implementation detail.
  • The local (Scala) type for StructType is Row. Unfortunately, this means that access to the individual fields is not type-safe.
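One way to limit the lack of type safety is to convert each Row into a case class at the start of the UDF, so that only a single function touches fields by name and the rest of the logic works with typed values. This is a minimal sketch, not a feature of the Spark API: the case class Subject (with kind in place of the Scala keyword type), the helper toSubject, the UDF highScoring and its threshold of 5 are all hypothetical, and it again assumes Spark 2.2 or later with a local SparkSession.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.udf

// Hypothetical typed view of one element of the "subjects" array.
// "kind" is used instead of "type", which is a reserved word in Scala.
case class Subject(kind: String, score: Long)

object TypedSubjects {
  // The only place that reads Row fields by name; a typo here still fails
  // at runtime, but the rest of the UDF body is checked by the compiler.
  def toSubject(r: Row): Subject =
    Subject(r.getAs[String]("type"), r.getAs[Long]("score"))

  // Returns the kinds whose score exceeds an arbitrary threshold of 5.
  val highScoring = udf { subjects: Seq[Row] =>
    subjects.map(toSubject).filter(_.score > 5).map(_.kind)
  }

  def run(spark: SparkSession): Array[Row] = {
    import spark.implicits._
    val json = """{"topic": "pets", "subjects": [{"type": "cat", "score": 10}, {"type": "dog", "score": 1}]}"""
    val data = spark.read.json(Seq(json).toDS())
    data.select($"topic", highScoring($"subjects").as("high")).collect()
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("typed-subjects").getOrCreate()
    try run(spark).foreach(println)
    finally spark.stop()
  }
}
```

The UDF returns a Seq[String], which Spark maps back to an array<string> column, so the result can be used like any other array column.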