
Spark Scala: Pass a sub type to a function accepting the parent type

Suppose I have an abstract class A. I also have classes B and C which inherit from class A.

abstract class A {
  def x: Int
}
case class B(i: Int) extends A {
  override def x = -i
}
case class C(i: Int) extends A {
  override def x = i
}


Given these classes, I construct the following RDD:

val data = sc.parallelize(Seq(
    Set(B(1), B(2)),
    Set(B(1), B(3)),
    Set(B(1), B(5))
  )).cache
  .zipWithIndex
  .map { case (k, v) => (v, k) }


I also have the following function that takes an RDD as input and returns the count of each element:

def f(data: RDD[(Long, Set[A])]) = {
  data.flatMap({
    case (k, v) => v map { af =>
      (af, 1)
    }
  }).reduceByKey(_ + _)
}


Note that the function accepts an RDD over the parent type A. Now, I expect val x = f(data) to return the counts as expected, since B is a subtype of A, but I get the following compile error:

type mismatch;
found : org.apache.spark.rdd.RDD[(Long, scala.collection.immutable.Set[B])]
required: org.apache.spark.rdd.RDD[(Long, Set[A])]
val x = f(data)


This error goes away if I change the function signature to f(data: RDD[(Long, Set[B])]); however, I can't do that because I want to use other subclasses in the RDD (like C).

I have also tried the following approach:

def f[T <: A](data: RDD[(Long, Set[T])]) = {
  data.flatMap({
    case (k, v) => v map { af =>
      (af, 1)
    }
  }) reduceByKey(_ + _)
}


However, this also fails, with the following compile error:

value reduceByKey is not a member of org.apache.spark.rdd.RDD[(T, Int)]
possible cause: maybe a semicolon is missing before `value reduceByKey'?
}) reduceByKey(_ + _)


I appreciate any help on this.

Answer

Set[T] is invariant in T: even though B is a subtype of A, Set[B] is neither a subtype nor a supertype of Set[A]. RDD[T] is also invariant in T, which further restricts the options: even if a covariant collection were used instead (e.g. List[+T]), the same situation would arise, because RDD[(Long, List[B])] still would not conform to RDD[(Long, List[A])].
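
To make the invariance point concrete, here is a minimal sketch in plain Scala (no Spark needed), reusing the A and B classes from the question:

val bs: Set[B] = Set(B(1), B(2))

// Does not compile: Set is invariant, so Set[B] does not conform to Set[A].
// val as: Set[A] = bs

// Compiles: requesting Set[A] up front types the elements as A.
val as: Set[A] = Set[A](B(1), B(2))

The same reasoning applies one level up: RDD[(Long, Set[B])] does not conform to RDD[(Long, Set[A])], which is exactly the mismatch reported in the compile error above.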

Instead, we can resort to a polymorphic form of the method. What's missing in the generic version above is a ClassTag, which Spark requires to preserve class information after erasure; without it, the implicit conversion that provides reduceByKey cannot be applied.

This should work:

import scala.reflect.ClassTag

def f[T: ClassTag](data: RDD[(Long, Set[T])]) = {
  data.flatMap({
    case (k, v) => v map { af =>
      (af, 1)
    }
  }).reduceByKey(_ + _)
}

Let's see:

val intRdd = sparkContext.parallelize(Seq((1L, Set(1, 2, 3)), (2L, Set(4, 5, 6))))
val res1 = f(intRdd).collect
// Array[(Int, Int)] = Array((4,1), (1,1), (5,1), (6,1), (2,1), (3,1))

val strRdd = sparkContext.parallelize(Seq((1L, Set("a", "b", "c")), (2L, Set("d", "e", "f"))))
val res2 = f(strRdd).collect
// Array[(String, Int)] = Array((d,1), (e,1), (a,1), (b,1), (f,1), (c,1))
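
The ClassTag context bound also combines with the upper bound from the question. As a sketch (assuming the A/B/C classes, the data RDD, and the imports defined above; g is just a name chosen here to avoid clashing with f), the original call site would then type-check:

// Keep the T <: A constraint and add the ClassTag context bound.
def g[T <: A : ClassTag](data: RDD[(Long, Set[T])]) =
  data.flatMap { case (_, v) => v.map(af => (af, 1)) }
    .reduceByKey(_ + _)

val x = g(data).collect
// B(1) appears in three of the sets, B(2), B(3) and B(5) once each,
// so this yields counts like Array((B(1),3), (B(2),1), (B(3),1), (B(5),1)) (order may vary).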