proto-n - 2 months ago
Scala Question

Override implicit parameter in a method called on an object, with the ability to manipulate this object


Edit: I changed the title to discourage answers that simply advise me to provide the implicit parameter explicitly.


I've been struggling with the example below for a while and haven't been able to find a solution. I extracted the outline of the problem from a bigger project (flink-ml), but it is really a Scala question.

The setup



Basically, I have the following setup for the part I'm dealing with:

//design to fit multiple models
trait PredictOperation[Self, InputType, OutputType] {
  def doPrediction(input: InputType): OutputType
}

class Predictor[Model] {
  def predict[InputType, OutputType](
      param1: InputType)(implicit
      predictOperation: PredictOperation[Model, InputType, OutputType])
    : OutputType =
  {
    predictOperation.doPrediction(param1)
  }
}


//a particular model that uses the above design
class SampleModel extends Predictor[SampleModel] {
  //...
}
object SampleModel {
  implicit val samplePredictOperation = new PredictOperation[SampleModel, Int, Double]{
    override def doPrediction(input:Int):Double = 2.0*input
  }
}

//some object using this model
object SomeObjectUsingModelPredict{
  def SomeFunctionUsingModelPredict(m : SampleModel): Unit ={
    println(m.predict(3))
  }
}


I use SomeObjectUsingModelPredict.SomeFunctionUsingModelPredict in my code:

//I want to override the doPrediction function of the model
object MainObject{
  def main(args: Array[String]): Unit ={
    val m = new SampleModel
    SomeObjectUsingModelPredict.SomeFunctionUsingModelPredict(m)
    //outputs 6.0
  }
}


The goal



I want to override the doPrediction function of SampleModel and hand SomeFunctionUsingModelPredict an object that uses this overridden behavior (so that it outputs 9.0 for this particular example).

Attempts



I've tried a few things, none of which work:

//does not compile: method 'predict' overrides nothing
object MainObject{
  def main(args: Array[String]): Unit ={
    val m = new SampleModel{
      override def predict(param1: Int)(implicit predictOperation: PredictOperation[SampleModel, Int, Double]): Double = 3*param1
    }
    SomeObjectUsingModelPredict.SomeFunctionUsingModelPredict(m)
  }
}

//outputs 6.0
object MainObject{
  def main(args: Array[String]): Unit ={
    implicit val samplePredictOperation = new PredictOperation[SampleModel, Int, Double]{
      override def doPrediction(input:Int):Double = 3.0*input
    }
    val m = new SampleModel
    SomeObjectUsingModelPredict.SomeFunctionUsingModelPredict(m)
  }
}

//outputs 6.0
object MainObject{
  def main(args: Array[String]): Unit ={
    val m = new SampleModel{
      implicit val samplePredictOperation = new PredictOperation[SampleModel, Int, Double]{
        override def doPrediction(input:Int):Double = 3.0*input
      }
      override def predict[InputType, OutputType](param1: InputType)(
          implicit predictOperation: PredictOperation[SampleModel, InputType, OutputType]): OutputType =
        super.predict(param1)(predictOperation)
    }
    SomeObjectUsingModelPredict.SomeFunctionUsingModelPredict(m)
  }
}

//outputs 6.0
object MainObject{
  class ExtSampleModel extends SampleModel {
    override def predict[InputType, OutputType](param1: InputType)(
        implicit predictOperation: PredictOperation[SampleModel, InputType, OutputType]): OutputType =
      super.predict(param1)(predictOperation)
  }
  object ExtSampleModel{
    implicit val samplePredictOperation = new PredictOperation[SampleModel, Int, Double]{
      override def doPrediction(input:Int):Double = 3.0*input
    }
  }
  def main(args: Array[String]): Unit ={
    val m = new ExtSampleModel
    SomeObjectUsingModelPredict.SomeFunctionUsingModelPredict(m)
  }
}


Question



Is it possible to accomplish what I'm trying to do? Ideas for refactoring are always welcome, but I'm primarily curious about this for the sake of the problem itself.

I can't really touch much of this design; for now I've worked around the problem by refactoring SomeObjectUsingModelPredict.

Answer

If you want a dirty hack, you can do it like this. First, you have to bypass the implicit parameter: throw away the PredictOperation that your predict method receives and replace it with your own. Then you have to bypass the type system with some casts.

object MainObject{
  class ExtSampleModel extends SampleModel {
    private val samplePredictOperation = new PredictOperation[SampleModel, Int, Double]{
      override def doPrediction(input:Int):Double = 3.0*input
    }

    override def predict[InputType, OutputType](param1: InputType)(implicit predictOperation: PredictOperation[SampleModel, InputType, OutputType]): OutputType = 
      super.predict(param1.asInstanceOf[Int])(samplePredictOperation).asInstanceOf[OutputType]
  }

  def main(args: Array[String]): Unit ={
    val m = new ExtSampleModel
    SomeObjectUsingModelPredict.SomeFunctionUsingModelPredict(m)
  }
}
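
For context on why the attempts in the question all print 6.0: implicit arguments are resolved at compile time, at the call site. Inside SomeFunctionUsingModelPredict the argument has the static type SampleModel, so the compiler picks the implicit from the SampleModel companion there, and nothing declared in MainObject can change that choice afterwards. The only remaining hook is the overridden predict method itself. The casts in the hack only hold when predict is called with an Int input and a Double output. Below is a minimal sketch of a slightly more defensive variant, not a definitive fix: it substitutes the operation only for Int inputs and otherwise defers to whatever implicit the caller resolved. The names GuardedSampleModel, tripling and GuardedDemo are made up for illustration, and the remaining cast still assumes the output type is Double for Int inputs.

```scala
// Same design as in the question.
trait PredictOperation[Self, InputType, OutputType] {
  def doPrediction(input: InputType): OutputType
}

class Predictor[Model] {
  def predict[InputType, OutputType](param1: InputType)(
      implicit predictOperation: PredictOperation[Model, InputType, OutputType]): OutputType =
    predictOperation.doPrediction(param1)
}

class SampleModel extends Predictor[SampleModel]

object SampleModel {
  implicit val samplePredictOperation: PredictOperation[SampleModel, Int, Double] =
    new PredictOperation[SampleModel, Int, Double] {
      override def doPrediction(input: Int): Double = 2.0 * input
    }
}

object SomeObjectUsingModelPredict {
  // The implicit is chosen here, at compile time, from the static type SampleModel.
  def SomeFunctionUsingModelPredict(m: SampleModel): Unit =
    println(m.predict(3))
}

// Hypothetical variant of the hack: only Int inputs are redirected.
class GuardedSampleModel extends SampleModel {
  private val tripling: PredictOperation[SampleModel, Int, Double] =
    new PredictOperation[SampleModel, Int, Double] {
      override def doPrediction(input: Int): Double = 3.0 * input
    }

  override def predict[InputType, OutputType](param1: InputType)(
      implicit predictOperation: PredictOperation[SampleModel, InputType, OutputType]): OutputType =
    param1 match {
      // Still an unchecked cast: assumes OutputType is Double whenever the input is an Int.
      case i: Int => tripling.doPrediction(i).asInstanceOf[OutputType]
      // Any other input type goes through the operation the caller resolved.
      case _      => super.predict(param1)(predictOperation)
    }
}

object GuardedDemo {
  def main(args: Array[String]): Unit = {
    SomeObjectUsingModelPredict.SomeFunctionUsingModelPredict(new SampleModel)        // prints 6.0
    SomeObjectUsingModelPredict.SomeFunctionUsingModelPredict(new GuardedSampleModel) // prints 9.0
  }
}
```

This keeps SomeObjectUsingModelPredict untouched, which was the constraint in the question. If that code can be changed after all, making the PredictOperation an implicit parameter of SomeFunctionUsingModelPredict would avoid the casts entirely.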