antikantian - 2 months ago
Scala Question

Is there a better way to get around type erasure on collections, specifically in the case of passing said collection to a Java method?

Because of the kind of projects I work on, I don't deal with type erasure all that much. That said, here's one method that bothers the hell out of me, and I'm stuck coming up with an alternative. The project does a lot of matrix multiplication, and I'm using fommil's netlib-java for native BLAS operations. Here's the method in question:

import scala.reflect.ClassTag
import scala.reflect.runtime.universe._
import com.github.fommil.netlib.BLAS

// Assumed: blas is netlib-java's BLAS instance (not shown in the original snippet)
val blas: BLAS = BLAS.getInstance()

def gemm[A: ClassTag: TypeTag](
  transA : String,
  transB : String,
  m : Int,
  n : Int,
  k : Int,
  alpha : A,
  a : Array[A],
  b : Array[A],
  beta : A): Array[A] = {

  // Column-major leading dimensions, as the reference BLAS expects
  val lda = if (transA == "N" || transA == "n") m else k
  val ldb = if (transB == "N" || transB == "n") k else n

  typeOf[A] match {
    case t if t =:= typeOf[Float] =>
      val _alpha = alpha.asInstanceOf[Float]
      val _beta = beta.asInstanceOf[Float]
      val _a = a.asInstanceOf[Array[Float]]
      val _b = b.asInstanceOf[Array[Float]]
      val outArray = new Array[Float](m * n)
      blas.sgemm(transA, transB, m, n, k, _alpha, _a, lda, _b, ldb, _beta, outArray, m)
      outArray.asInstanceOf[Array[A]]
    case t if t =:= typeOf[Double] =>
      val _alpha = alpha.asInstanceOf[Double]
      val _beta = beta.asInstanceOf[Double]
      val _a = a.asInstanceOf[Array[Double]]
      val _b = b.asInstanceOf[Array[Double]]
      val outArray = new Array[Double](m * n)
      blas.dgemm(transA, transB, m, n, k, _alpha, _a, lda, _b, ldb, _beta, outArray, m)
      outArray.asInstanceOf[Array[A]]
    case _ =>
      // gemm_ref is a pure-Scala fallback defined elsewhere in the project
      val outArray = Predef.implicitly[ClassTag[A]].newArray(m * n)
      gemm_ref(transA, transB, m, n, k, alpha, a, b, beta, outArray)
      outArray
  }
}


An alternative I've considered is a type-safe cast with Shapeless's Typeable/TypeCase. As I understand it, this works by going through each element in the collection to check that they all have the expected type. That adds overhead, and since my arrays often have a lot of elements, I'd rather avoid any extra cost.
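A side note on the overhead concern: unlike List or Vector, a JVM array keeps its element class at runtime (an Array[Float] is a float[]), so a type test on an array is a single class check that never walks the elements. A minimal sketch; the describe helper is hypothetical and only there for illustration:

```scala
object ArrayTypeTest {
  // Hypothetical helper: dispatches on the runtime class of the array.
  // Arrays are reified on the JVM (Array[Float] is float[]), so each case
  // is one class check whose cost does not depend on the array length.
  def describe(xs: Array[_]): String = xs match {
    case _: Array[Float]  => "float[]"
    case _: Array[Double] => "double[]"
    case _                => "other"
  }

  def main(args: Array[String]): Unit = {
    println(describe(Array(1.0f, 2.0f))) // prints float[]
    println(describe(Array(1.0, 2.0)))   // prints double[]
    println(describe(Array("a")))        // prints other
  }
}
```

This is a different technique from the Typeable/TypeCase cast, and it only covers the array itself; scalar parameters such as alpha and beta would still need their own casts.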

Answer

How about something like this?

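import com.github.fommil.netlib.BLAS

trait Blas[A] {
  def gemm(transA: String, transB: String, m: Int, n: Int, k: Int,
           alpha: A, a: Array[A], lda: Int, b: Array[A], ldb: Int, beta: A): Array[A]
}

object Blas {
  // netlib-java entry point (uses a native implementation if one can be loaded)
  private val netlib: BLAS = BLAS.getInstance()

  implicit val floatBlas: Blas[Float] = new Blas[Float] {
    override def gemm(transA: String, transB: String, m: Int, n: Int, k: Int,
                      alpha: Float, a: Array[Float], lda: Int, b: Array[Float], ldb: Int,
                      beta: Float): Array[Float] = {
      val outArray = new Array[Float](m * n)
      netlib.sgemm(transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, outArray, m)
      outArray
    }
  }

  implicit val doubleBlas: Blas[Double] = new Blas[Double] {
    override def gemm(transA: String, transB: String, m: Int, n: Int, k: Int,
                      alpha: Double, a: Array[Double], lda: Int, b: Array[Double], ldb: Int,
                      beta: Double): Array[Double] = {
      val outArray = new Array[Double](m * n)
      netlib.dgemm(transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, outArray, m)
      outArray
    }
  }

  // etc. (e.g. instances for other element types that wrap your gemm_ref fallback)
}

def gemm[A](
    transA: String,
    transB: String,
    m: Int,
    n: Int,
    k: Int,
    alpha: A,
    a: Array[A],
    b: Array[A],
    beta: A
)(implicit ev: Blas[A]): Array[A] = {

  // Column-major leading dimensions, as the reference BLAS expects
  val lda = if (transA == "N" || transA == "n") m else k
  val ldb = if (transB == "N" || transB == "n") k else n

  // The compiler selects Blas[Float] or Blas[Double] from the static type of A
  ev.gemm(transA, transB, m, n, k, alpha, a, lda, b, ldb, beta)
}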

(You may have to adjust the object and method names to match your code; I don't know exactly what they refer to in your project.)

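The idea is that you pass an additional implicit parameter, which the compiler looks up automatically based on the element type A. When defining those instances, you have the full type information available, so you need neither the typeOf match nor the casts, and an element type with no instance is rejected at compile time.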
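To see the implicit dispatch in isolation, here is a self-contained sketch that replaces netlib-java with a naive pure-Scala column-major multiply (no transposition, no beta term). The Gemm trait, its instances and the TypeClassGemm object are hypothetical stand-ins for the Blas type class above, not part of any library:

```scala
object TypeClassGemm {
  // Hypothetical type class: computes C = alpha * A * B for column-major
  // A (m x k) and B (k x n), returning column-major C (m x n).
  trait Gemm[A] {
    def gemm(m: Int, n: Int, k: Int, alpha: A, a: Array[A], b: Array[A]): Array[A]
  }

  object Gemm {
    // Each instance works on the concrete element type: no casts, no TypeTag.
    implicit val floatGemm: Gemm[Float] = new Gemm[Float] {
      def gemm(m: Int, n: Int, k: Int, alpha: Float, a: Array[Float], b: Array[Float]): Array[Float] = {
        val c = new Array[Float](m * n)
        for (j <- 0 until n; i <- 0 until m) {
          var sum = 0.0f
          for (p <- 0 until k) sum += a(i + p * m) * b(p + j * k)
          c(i + j * m) = alpha * sum
        }
        c
      }
    }

    implicit val doubleGemm: Gemm[Double] = new Gemm[Double] {
      def gemm(m: Int, n: Int, k: Int, alpha: Double, a: Array[Double], b: Array[Double]): Array[Double] = {
        val c = new Array[Double](m * n)
        for (j <- 0 until n; i <- 0 until m) {
          var sum = 0.0
          for (p <- 0 until k) sum += a(i + p * m) * b(p + j * k)
          c(i + j * m) = alpha * sum
        }
        c
      }
    }
  }

  // Generic entry point: the instance is found in Gemm's companion object.
  def gemm[A](m: Int, n: Int, k: Int, alpha: A, a: Array[A], b: Array[A])(implicit ev: Gemm[A]): Array[A] =
    ev.gemm(m, n, k, alpha, a, b)

  def main(args: Array[String]): Unit = {
    val a = Array(1.0, 2.0, 3.0, 4.0) // 2x2 column-major: rows (1, 3) and (2, 4)
    val b = Array(1.0, 0.0, 0.0, 1.0) // 2x2 identity
    println(gemm(2, 2, 2, 2.0, a, b).mkString(",")) // prints 2.0,4.0,6.0,8.0
  }
}
```

Calling gemm with, say, Array[Int] does not compile, because no implicit Gemm[Int] is in scope; that replaces the runtime fallthrough of the typeOf match with a compile-time check.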