Vektor88 - 6 days ago
Scala Question

Joining two Spark DataFrame according to size of intersection of two array columns

I have two DataFrames in my Spark (v1.5.0) code:

aDF = [user_id : Int, user_purchases: array<int>]
bDF = [user_id : Int, user_purchases: array<int>]


What I want to do is join these two DataFrames, but I only need the rows where the size of the intersection between aDF.user_purchases and bDF.user_purchases is greater than 2.

Do I have to use the RDD API, or is it possible to use some function from org.apache.spark.sql.functions?

Answer

I don't see a built-in function for this, but you can use a UDF:

import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.functions.{col, udf}

// Count how many elements of `a` also appear in `b`.
val intersect = udf { (a: WrappedArray[Int], b: WrappedArray[Int]) =>
  a.count(b.contains)
}
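As a sanity check, the UDF body is just an intersection count; outside Spark it boils down to plain Scala (a minimal sketch, not part of the original answer — `intersectCount` is a hypothetical helper name):

```scala
// Count elements of `a` that also appear in `b` (same logic as the UDF body).
def intersectCount(a: Seq[Int], b: Seq[Int]): Int = a.count(b.contains)

println(intersectCount(Seq(1, 2, 3), Seq(2, 3, 4))) // prints 2
```

Note that duplicates in `a` are counted once per occurrence, so this is a multiset-style count on the left-hand side.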
// test data sets
val one = sc.parallelize(List(
    (1, Array(1, 2, 3)),
    (2, Array(1, 2, 3, 4)),
    (3, Array(1, 2, 3)),
    (4, Array(1, 2))
  )).toDF("user", "arr")

val two = sc.parallelize(List(
    (1, Array(1, 2, 3)),
    (2, Array(1, 2, 3, 4)),
    (3, Array(1, 2, 3)),
    (4, Array(1))
  )).toDF("user", "arr")

// usage
one.join(two, one("user") === two("user"))
  .select(one("user"), intersect(one("arr"), two("arr")).as("intersect"))
  .where(col("intersect") > 2)
  .show

// cross-join version from the comments
one.join(two)
  .select(one("user"), two("user"), intersect(one("arr"), two("arr")).as("intersect"))
  .where('intersect > 2)
  .show
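For what it's worth, newer Spark releases (2.4+) ship built-ins that make the UDF unnecessary; a sketch, assuming the same `one`/`two` DataFrames as above:

```scala
import org.apache.spark.sql.functions.{array_intersect, size}

// size(array_intersect(...)) computes the intersection cardinality natively.
one.join(two, one("user") === two("user"))
  .where(size(array_intersect(one("arr"), two("arr"))) > 2)
  .show()
```

Be aware that `array_intersect` has set semantics (duplicates are removed), which matches the question's intent but can differ from the UDF above when the arrays contain duplicate entries.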