
spark custom Aggregator >=2.0 (scala)

I have the following Dataset:

val myDS = List(("a",1,1.1), ("b",2,1.2), ("a",3,3.1), ("b",4,1.4), ("a",5,5.1)).toDS
// and aggregation
// myDS.groupByKey(t2 => t2._1).agg(myAvg).collect()


I want to write a custom aggregate function myAvg which takes Tuple3 arguments and returns sum(_._2) / sum(_._3).

I know it can be computed in other ways, but I want to write a custom aggregate.
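
For the sample data above, that means one value per key; rough arithmetic, just to illustrate the goal: key "a" should give (1 + 3 + 5) / (1.1 + 3.1 + 5.1) = 9 / 9.3 ≈ 0.9677, and key "b" should give (2 + 4) / (1.2 + 1.4) = 6 / 2.6 ≈ 2.3077.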

I wrote something like this:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

val myAvg = new Aggregator[Tuple3[String, Integer, Double],
                           Tuple2[Integer, Double],
                           Double] {
  def zero: Tuple2[Integer, Double] = Tuple2(0, 0.0)
  def reduce(agg: Tuple2[Integer, Double],
             a: Tuple3[String, Integer, Double]): Tuple2[Integer, Double] =
    Tuple2(agg._1 + a._2, agg._2 + a._3)
  def merge(agg1: Tuple2[Integer, Double],
            agg2: Tuple2[Integer, Double]): Tuple2[Integer, Double] =
    Tuple2(agg1._1 + agg2._1, agg1._2 + agg2._2)
  def finish(res: Tuple2[Integer, Double]): Double = res._1 / res._2
  def bufferEncoder: Encoder[(Integer, Double)] =
    Encoders.tuple(Encoders.INT, Encoders.scalaDouble)
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}.toColumn()


Unfortunately I receive the following error:

java.lang.RuntimeException: Unsupported literal type class scala.runtime.BoxedUnit ()
at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:75)
at org.apache.spark.sql.functions$.lit(functions.scala:101)
at org.apache.spark.sql.Column.apply(Column.scala:217)


What's wrong?

In my local Spark 2.1 I also receive one warning:

warning: there was one deprecation warning; re-run with -deprecation for details


What is deprecated in my code?

Thanks for any advice.

Answer

It seems that the problem here is your use of Java's Integer instead of Scala's Int. If you replace all usages of Integer in your Aggregator implementation with Int (and replace Encoders.INT with Encoders.scalaInt), it works as expected:

import org.apache.spark.sql.TypedColumn // needed for the explicit type annotation below

val myAvg: TypedColumn[(String, Int, Double), Double] =
  new Aggregator[(String, Int, Double), (Int, Double), Double] {
    def zero: (Int, Double) = (0, 0.0)
    def reduce(agg: (Int, Double), a: (String, Int, Double)): (Int, Double) =
      (agg._1 + a._2, agg._2 + a._3)
    def merge(agg1: (Int, Double), agg2: (Int, Double)): (Int, Double) =
      (agg1._1 + agg2._1, agg1._2 + agg2._2)
    def finish(res: (Int, Double)): Double = res._1 / res._2
    def bufferEncoder: Encoder[(Int, Double)] =
      Encoders.tuple(Encoders.scalaInt, Encoders.scalaDouble)
    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }.toColumn

(I also applied some syntactic sugar, removing the explicit Tuple2/Tuple3 references and calling toColumn without parentheses.)
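
As a quick check, here is a usage sketch (assuming a spark-shell session with spark.implicits._ in scope for toDS, plus the imports above):

val myDS = List(("a",1,1.1), ("b",2,1.2), ("a",3,3.1), ("b",4,1.4), ("a",5,5.1)).toDS
// group by the String key and apply the typed aggregator
myDS.groupByKey(_._1).agg(myAvg).collect()
// expected result (up to ordering): Array((a,0.9677419...), (b,2.3076923...))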
