Roshan A - 4 months ago
Scala Question

Pass Array[Seq[String]] to a UDF in Spark Scala

I am new to UDFs in Spark. I have also read the answer here.

Problem statement: I am trying to do pattern matching on a DataFrame column.

Example DataFrame:

val df = Seq((1, Some("z")), (2, Some("abs,abc,dfg")),(3,Some("a,b,c,d,e,f,abs,abc,dfg"))).toDF("id", "text")


df.show()

+---+--------------------+
| id|                text|
+---+--------------------+
|  1|                   z|
|  2|         abs,abc,dfg|
|  3|a,b,c,d,e,f,abs,a...|
+---+--------------------+


df.filter($"text".contains("abs,abc,dfg")).count()
// returns 2, because "abs,abc,dfg" appears in the 2nd and 3rd rows
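Column.contains is a substring test, so the filter above counts every row whose text includes the whole string "abs,abc,dfg", not rows that are equal to it. The same count on a local Scala collection, as a sketch that runs without Spark (the texts sequence is just the sample column copied by hand, and countContaining is a hypothetical helper name):

```scala
// Hand-copied values of the sample text column (assumption: no nulls)
val texts = Seq("z", "abs,abc,dfg", "a,b,c,d,e,f,abs,abc,dfg")

// Local equivalent of df.filter($"text".contains(pattern)).count():
// String.contains is a substring test, not an equality test.
def countContaining(values: Seq[String], pattern: String): Long =
  values.count(_.contains(pattern)).toLong

val n = countContaining(texts, "abs,abc,dfg") // rows 2 and 3 match
println(n)
```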


Now I want to do this pattern matching for every row of the text column, using that row's own value as the pattern, and add a new column called count with the number of rows that contain it.
Expected result:

+---+--------------------+-----+
| id|                text|count|
+---+--------------------+-----+
|  1|                   z|    1|
|  2|         abs,abc,dfg|    2|
|  3|a,b,c,d,e,f,abs,a...|    1|
+---+--------------------+-----+
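In other words, count for a row is the number of rows (including the row itself) whose text contains that row's text. A minimal plain-Scala sketch of those semantics on the hand-copied sample rows, with no Spark involved (rows, allTexts and withCount are illustrative names, not part of the question):

```scala
// Sample rows copied from the question: (id, text)
val rows = Seq((1, "z"), (2, "abs,abc,dfg"), (3, "a,b,c,d,e,f,abs,abc,dfg"))
val allTexts = rows.map(_._2)

// For every row, count how many texts in the whole column contain this row's text.
// A row always matches itself, so the minimum count is 1.
val withCount = rows.map { case (id, text) =>
  (id, text, allTexts.count(_.contains(text)))
}

withCount.foreach(println)
```

This is a nested loop over the whole column, which is why the Spark version needs every row's text available while each row is processed.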


I tried to define a UDF that takes the text column as an Array[Seq[String]], but I am not able to get what I intended.

What I tried so far:

val txt = df.select("text").collect.map(_.toSeq.map(_.toString)) // convert column to Array[Seq[String]]
val valsum = udf((txt: Array[Seq[String]], pattern: String) => { txt.count(_ == pattern) })
df.withColumn("newCol", valsum(lit(txt), df("text"))).show()
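Two things get in the way of the counting logic in this attempt: each collected element is a one-field Seq[String] (one inner Seq per row) rather than a String, and _ == pattern is an exact comparison, while the filter earlier used a substring test. A plain-Scala sketch of the difference, using a hand-built array shaped like the collect result (countMatches and countEqual are hypothetical helper names). It only covers the comparison logic, not how the collected data reaches the UDF:

```scala
// Shape produced by df.select("text").collect.map(_.toSeq.map(_.toString)):
// one inner Seq per row, holding that row's single selected field.
val txt: Array[Seq[String]] =
  Array(Seq("z"), Seq("abs,abc,dfg"), Seq("a,b,c,d,e,f,abs,abc,dfg"))

// Flatten to one String per row, then use a substring test (like Column.contains).
def countMatches(rows: Array[Seq[String]], pattern: String): Int =
  rows.flatten.count(_.contains(pattern))

// Exact equality only finds rows identical to the pattern.
def countEqual(rows: Array[Seq[String]], pattern: String): Int =
  rows.flatten.count(_ == pattern)

println(countMatches(txt, "abs,abc,dfg"))
println(countEqual(txt, "abs,abc,dfg"))
```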


Any help would be appreciated.

Answer

Each row needs to see all the values of the text column. You can get them with collect_list over a window that puts every row of the DataFrame into one group. Then, for each row, count how many elements of the collected array contain that row's text, as in the following code.

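import sqlContext.implicits._ // in a Spark 2.x+ shell, spark.implicits._ works too
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._ // provides Window

val df = Seq((1, Some("z")), (2, Some("abs,abc,dfg")), (3, Some("a,b,c,d,e,f,abs,abc,dfg"))).toDF("id", "text")

// For one row's text, count how many collected values contain it as a substring.
// Spark passes array columns to Scala UDFs as a Seq (a WrappedArray in Scala 2.11/2.12).
val valsum = udf((txt: String, array: Seq[String]) => array.count(element => element.contains(txt)))

df.withColumn("grouping", lit("g"))                                              // constant key: all rows in one group
  .withColumn("array", collect_list("text").over(Window.partitionBy("grouping"))) // every text value, attached to each row
  .withColumn("count", valsum($"text", $"array"))
  .drop("grouping", "array")
  .show(false)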

You should get the following output:

+---+-----------------------+-----+
|id |text                   |count|
+---+-----------------------+-----+
|1  |z                      |1    |
|2  |abs,abc,dfg            |2    |
|3  |a,b,c,d,e,f,abs,abc,dfg|1    |
+---+-----------------------+-----+
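One caveat worth checking on your Spark version: the sample text column is built from Option values, so it is nullable. If a row's text is null, Spark passes null into the String parameter of the UDF, and String.contains(null) throws a NullPointerException. A hypothetical null-safe version of the UDF body, written in plain Scala so it can be tested without Spark (returning 0 for a null text is a design choice, not something the original specifies):

```scala
// Hypothetical null-safe variant of the answer's UDF body.
// txt: the current row's text; array: all collected text values.
def countContainingSafe(txt: String, array: Seq[String]): Int =
  if (txt == null) 0 // assumption: a null text matches nothing
  else array.count(e => e != null && e.contains(txt))

val sample = Seq("z", "abs,abc,dfg", "a,b,c,d,e,f,abs,abc,dfg")
println(countContainingSafe("abs,abc,dfg", sample))
println(countContainingSafe(null, sample))
```

In Spark it would be wrapped the same way as the original, for example udf(countContainingSafe _).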

I hope this is helpful.