newToScala newToScala - 3 months ago 7
Scala Question

Filtering a collection based on an arbitrary number of options

How can the following Scala function be refactored to use idiomatic best practices?

/** Returns the elements of `ids` that also appear in every filter list that is present.
  *
  * Each `Some(list)` narrows the result by intersection; a `None` filter is skipped.
  * Element order follows `ids`, and duplicates follow `Seq.intersect` (multiset) semantics.
  *
  * @param ids                               the candidate ids
  * @param idsMustBeInThisListIfItExists     optional first filter list
  * @param idsMustAlsoBeInThisListIfItExists optional second filter list
  * @return `ids` restricted to the ids present in each defined filter
  */
def getFilteredList(ids: Seq[Int],
                    idsMustBeInThisListIfItExists: Option[Seq[Int]],
                    idsMustAlsoBeInThisListIfItExists: Option[Seq[Int]]): Seq[Int] =
  // `flatten` drops the absent filters, so the remaining ones fold uniformly
  // (no `var`, no `Option.get`), and more filters could be added to the list.
  Seq(idsMustBeInThisListIfItExists, idsMustAlsoBeInThisListIfItExists).flatten
    .foldLeft(ids)(_ intersect _)


Expected input and output:

// Each present filter is applied by intersection, so an id must appear in every defined list.
val ids = Seq(1,2,3,4,5)
val output1 = getFilteredList(ids, None, Some(Seq(3,5))) // 3, 5
val output2 = getFilteredList(ids, None, None) // 1,2,3,4,5
val output3 = getFilteredList(ids, Some(Seq(1,2)), None) // 1,2
val output4 = getFilteredList(ids, Some(Seq(1)), Some(Seq(5))) // empty: 1 is not in Seq(5) and 5 is not in Seq(1). NOTE(review): "1,5" would require union semantics, which the code above does not implement — confirm the intended behavior


Thank you for your time.

Answer

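Here's a simple way to do this: add an extension method that intersects a `Seq` with an `Option[Seq]`, leaving the sequence unchanged for `None`, and then chain it once per filter: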

  /** Extension that lets a `Seq` be intersected with an optional `Seq`.
    *
    * @param seq the sequence being narrowed
    */
  implicit class SeqAugmenter[T](val seq: Seq[T]) extends AnyVal {

    /** Intersects `seq` with the wrapped list when `opt` is defined.
      *
      * @param opt optional list to intersect with
      * @return `seq` itself for `None`, otherwise `seq intersect other`
      */
    def intersect(opt: Option[Seq[T]]): Seq[T] =
      opt
        .map(other => seq.intersect(other))
        .getOrElse(seq)
  }

  /** Keeps the ids that also appear in each filter list that is present.
    *
    * Uses the `SeqAugmenter` extension, whose `intersect(Option[Seq[T]])`
    * returns the sequence unchanged when the option is `None`.
    *
    * @param ids                               the candidate ids
    * @param idsMustBeInThisListIfItExists     optional first filter list
    * @param idsMustAlsoBeInThisListIfItExists optional second filter list
    * @return `ids` narrowed by every defined filter, applied in order
    */
  def getFilteredList(
    ids: Seq[Int],
    idsMustBeInThisListIfItExists: Option[Seq[Int]],
    idsMustAlsoBeInThisListIfItExists: Option[Seq[Int]]
  ): Seq[Int] = {
    val afterFirstFilter = ids.intersect(idsMustBeInThisListIfItExists)
    afterFirstFilter.intersect(idsMustAlsoBeInThisListIfItExists)
  }
Comments