likern - 10 months ago
Scala Question

Write function with type parameter

I have a unit test that tests one solution. The same test code could also be applied to another, very similar solution. I want to make the test code generic so that it applies to both solutions, like this:

describe("when table contains all correct rows") {
  it("should be empty") {
    def check[T](func: T => List[Row]) = {
      val tableGen = new TableGenerator()
      val table: Vector[Row] = tableGen.randomTable(100)
        .sortWith(_.time isBefore _.time).distinct
      val result: List[Row] = func(table)
      // ... (rest of the test not shown)
    }
  }
}


where the solutions have these types:

solution1: IndexedSeq[Row] => List[Row]
solution2: Seq[Row] => List[Row]

How should the check() function be written to make this possible?
And what is the best approach (possibly a different one) to eliminate this code duplication?

When I try to compile this code, I get a type mismatch error:

Error:(36, 29) type mismatch;
found : table.type (with underlying type scala.collection.immutable.Vector[com.vmalov.tinkoff.Row])
required: T
val result = func(table)

Answer

For this to work, you need to be able to pass a Vector[Row] to func, so every Vector[Row] has to be a T; that is, T must be a supertype of Vector[Row]. You can tell the compiler this with a lower type bound:

def check[T >: Vector[Row]](func: T => List[Row])
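
A minimal sketch of the bounded version, assuming a hypothetical Row with a single id field and two hypothetical solutions that have the argument types from the question. The real Row, TableGenerator and solutions are not shown, so a fixed three-row table stands in for tableGen.randomTable(100). Both solutions are accepted because IndexedSeq[Row] and Seq[Row] are both supertypes of Vector[Row]:

```scala
// Hypothetical stand-ins: the real Row, TableGenerator and solutions are not shown in the question.
final case class Row(id: Int)

// Two hypothetical "solutions" with the argument types from the question;
// each returns the rows it considers invalid (here: rows with a negative id).
val solution1: IndexedSeq[Row] => List[Row] = rows => rows.filter(_.id < 0).toList
val solution2: Seq[Row] => List[Row] = rows => rows.filter(_.id < 0).toList

// The lower bound T >: Vector[Row] guarantees the Vector built below is a T,
// so func can be applied to it.
def check[T >: Vector[Row]](func: T => List[Row]): List[Row] = {
  val table: Vector[Row] = Vector(Row(1), Row(2), Row(3)) // fixed table instead of tableGen.randomTable(100)
  func(table)
}

// Both calls compile: IndexedSeq[Row] and Seq[Row] are supertypes of Vector[Row].
val r1 = check(solution1)
val r2 = check(solution2)
```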

Alternatively, by the same reasoning, a function T => List[Row] is also a function Vector[Row] => List[Row] exactly when T is a supertype of Vector[Row], and the Scala compiler knows this (functions are contravariant in their argument types). So the bounded signature is equivalent to the simpler one:

def check(func: Vector[Row] => List[Row])
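
The same idea without a type parameter, again as a sketch with hypothetical Row, solution1 and solution2 stand-ins. Function1 is declared as Function1[-T1, +R], so wherever a Vector[Row] => List[Row] is expected, the compiler also accepts an IndexedSeq[Row] => List[Row] or a Seq[Row] => List[Row], including in a plain val assignment:

```scala
// Hypothetical stand-ins: the real Row and solutions are not shown in the question.
final case class Row(id: Int)

val solution1: IndexedSeq[Row] => List[Row] = rows => rows.filter(_.id < 0).toList
val solution2: Seq[Row] => List[Row] = rows => rows.filter(_.id < 0).toList

// Function1 is contravariant in its argument type, so a function accepting any
// IndexedSeq[Row] (or any Seq[Row]) is also a Vector[Row] => List[Row].
def check(func: Vector[Row] => List[Row]): List[Row] = {
  val table: Vector[Row] = Vector(Row(1), Row(-2), Row(3)) // hypothetical fixed table
  func(table)
}

// The same widening works without calling check at all.
val asVectorFn: Vector[Row] => List[Row] = solution2

val r1 = check(solution1)
val r2 = check(solution2)
```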

You can generalize this further; how far depends on what you need. For example, you can replace List[Row] with Seq[Row] (everywhere), or replace it with a type parameter and pass an extra function to check:

def check[A](func: Vector[Row] => A)(test: A => Boolean) = {
  val table = ... // build the table as before
  val result = func(table)
  assert(test(result))
}

check(Solution.solution1)(_.isEmpty) // the compiler infers A is List[Row]
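
A self-contained sketch of the generalized version, with the same hypothetical Row, solution1 and solution2 stand-ins. As an assumption made so it can run outside a test framework, this check returns the verdict of test instead of asserting. Because test sits in a second parameter list, the compiler fixes A from func before it type-checks the _.isEmpty lambda:

```scala
// Hypothetical stand-ins: the real Row and solutions are not shown in the question.
final case class Row(id: Int)

val solution1: IndexedSeq[Row] => List[Row] = rows => rows.filter(_.id < 0).toList
val solution2: Seq[Row] => List[Row] = rows => rows.filter(_.id < 0).toList

// A is whatever the solution returns; test decides whether that result is acceptable.
// Returning the Boolean (instead of asserting) is a choice made for this sketch.
def check[A](func: Vector[Row] => A)(test: A => Boolean): Boolean = {
  val table: Vector[Row] = Vector(Row(1), Row(2), Row(3)) // hypothetical fixed table
  val result = func(table)
  test(result)
}

val ok1 = check(solution1)(_.isEmpty) // A inferred as List[Row]
val ok2 = check(solution2)(_.isEmpty)
// A solution with a different result type: here A is inferred as Int.
val ok3 = check((rows: Seq[Row]) => rows.size)(_ == 3)
```

The last call shows what the extra type parameter buys: a function returning something other than List[Row] can be checked with the same helper, which the List[Row]-only signature would reject.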