Scala Question

How to apply a custom filtering function on a Spark DataFrame

I have a DataFrame of the form:

A_DF = |id_A: Int|concatCSV: String|


and another one:

B_DF = |id_B: Int|triplet: List[String]|


Examples of concatCSV could look like:

"StringD, StringB, StringF, StringE, StringZ"
"StringA, StringB, StringX, StringY, StringZ"
...


while a triplet is something like:

("StringA", "StringF", "StringZ")
("StringB", "StringU", "StringR")
...


I want to produce the Cartesian product of A_DF and B_DF, e.g.:

| id_A: Int | concatCSV: String | id_B: Int | triplet: List[String] |
| 14 | "StringD, StringB, StringF, StringE, StringZ" | 21 | ("StringA", "StringF", "StringZ")|
| 14 | "StringD, StringB, StringF, StringE, StringZ" | 45 | ("StringB", "StringU", "StringR")|
| 18 | "StringA, StringB, StringX, StringY, StringG" | 21 | ("StringA", "StringF", "StringZ")|
| 18 | "StringA, StringB, StringX, StringY, StringG" | 45 | ("StringB", "StringU", "StringR")|
| ... | | | |


Then keep just the records that have at least two substrings (e.g. StringA, StringB) from A_DF("concatCSV") that appear in B_DF("triplet"), i.e. use filter to exclude those that don't satisfy this condition.

First question is: can I do this without converting the DFs into RDDs?

Second question is: can I ideally do the whole thing in the join step, as a where condition?

I have tried experimenting with something like:

val cartesianRDD = A_DF
  .join(B_DF, "right")
  .where($"triplet".exists($"concatCSV".contains(_)))


but where cannot be resolved. I tried it with filter instead of where, but still no luck. Also, for some strange reason, the type of cartesianRDD shows up as SchemaRDD and not DataFrame. How did I end up with that? Finally, what I am trying above (the short code I wrote) is incomplete, as it would keep records with just one substring from concatCSV found in triplet.

So, third question is: Should I just change to RDDs and solve it with a custom filtering function?

Finally, last question: Can I use a custom filtering function with DataFrames?

Thanks for the help.

Answer

CROSS JOIN is supported in Hive, so you could first do the cross join using Hive SQL:

A_DF.registerTempTable("a")
B_DF.registerTempTable("b")

// sqlContext must actually be a HiveContext for CROSS JOIN to work
val result = sqlContext.sql("SELECT * FROM a CROSS JOIN b")
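As a side note, on newer Spark versions the same Cartesian product can be expressed directly on the DataFrames without going through Hive SQL. This is only a sketch under that assumption (result2 is an illustrative name, not part of the original code):

// Sketch: same Cartesian product without Hive SQL.
// Assumes Spark 2.1+, where Dataset.crossJoin is available; on Spark 1.x an
// unconditioned A_DF.join(B_DF) also produces the Cartesian product.
val result2 = A_DF.crossJoin(B_DF)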

Then you can filter down to your expected output using two UDFs: one that converts the CSV string into an array of words, and a second one that gives the length of the intersection of the resulting array column and triplet:

import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.functions.{col, udf}

// split the CSV string into an array of trimmed tokens
val splitArr = udf { (s: String) => s.split(",").map(_.trim) }

// count how many elements the two arrays have in common
val commonLen = udf { (a: WrappedArray[String],
                       b: WrappedArray[String]) => a.intersect(b).length }

val temp = (result.withColumn("concatArr",
  splitArr(col("concatCSV"))).select(col("*"),
  commonLen(col("triplet"), col("concatArr")).alias("comm"))
  .filter(col("comm") >= 2)
  .drop("comm")
  .drop("concatArr"))

temp.show
+----+--------------------+----+--------------------+
|id_A|           concatCSV|id_B|             triplet|
+----+--------------------+----+--------------------+
|  14|StringD, StringB,...|  21|[StringA, StringF...|
|  18|StringA, StringB,...|  21|[StringA, StringF...|
+----+--------------------+----+--------------------+
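For completeness, if you are on Spark 2.4 or later the same filter can be written with built-in functions instead of UDFs. This is a sketch under that assumption (viaBuiltins and concatArr are illustrative names), not part of the original Hive-based approach:

import org.apache.spark.sql.functions.{array_intersect, col, size, split}

// Sketch assuming Spark 2.4+, where array_intersect is a built-in function.
// Split concatCSV on commas (and surrounding whitespace), intersect with
// triplet, and keep rows where at least two elements overlap.
val viaBuiltins = result
  .withColumn("concatArr", split(col("concatCSV"), ",\\s*"))
  .filter(size(array_intersect(col("triplet"), col("concatArr"))) >= 2)
  .drop("concatArr")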