kharandziuk - 1 year ago

Scala. Idiomatic flatten function implementation

I am trying to implement a flatten function in Scala.
I ended up with something like this:

// implementation
def flatten(xs: List[Any]): List[Any] =
  xs match {
    case List() => List()
    case y :: ys => y match {
      case k :: ks => flatten(List(k)) ::: flatten(ks) ::: flatten(ys)
      case Nil => flatten(ys) // an empty nested list contributes nothing
      case _ => y :: flatten(ys)
    }
  }

// something like tests
def main(args: Array[String]): Unit = {
  val f1 = flatten(List(List(1, 1), 2, List(3, List(5, 8))))
  assert(f1 == List(1, 1, 2, 3, 5, 8))
  val f2 = flatten(List(List(List(1), List(1)), 2, List(3, List(5, 8))))
  assert(f2 == List(1, 1, 2, 3, 5, 8))
}

This implementation works, but it relies on list concatenation (:::), which I suspect is slow. Can somebody provide (or explain) a solution without list concatenation?

I googled a little, but most of the questions I found are about the built-in flatten.

Answer

For starters, as @om-nom-nom pointed out, there is really no point in talking about anything being idiomatic without addressing the List[Any]: it erases the element types, so the compiler cannot tell a nested list from a plain value. Let's see if we can describe the data with a better type.

sealed trait Tree[A]
case class Node[A](l: List[Tree[A]]) extends Tree[A]
case class Leaf[A](a: A) extends Tree[A]

def flatten[A](tree: Tree[A]): List[A] = ??? // to be filled in
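To connect this back to the question, here is one way to encode its two test inputs as Tree[Int] values: every nested List becomes a Node and every bare value becomes a Leaf. This encoding is an illustration, not part of the original answer; t1 and t2 are hypothetical names.

```scala
// The Tree ADT from the answer.
sealed trait Tree[A]
case class Node[A](l: List[Tree[A]]) extends Tree[A]
case class Leaf[A](a: A) extends Tree[A]

// List(List(1, 1), 2, List(3, List(5, 8))) from the question:
// each nested List becomes a Node, each bare value a Leaf.
val t1: Tree[Int] =
  Node(List(
    Node(List(Leaf(1), Leaf(1))),
    Leaf(2),
    Node(List(Leaf(3), Node(List(Leaf(5), Leaf(8)))))
  ))

// List(List(List(1), List(1)), 2, List(3, List(5, 8))) from the question.
val t2: Tree[Int] =
  Node(List(
    Node(List(Node(List(Leaf(1))), Node(List(Leaf(1))))),
    Leaf(2),
    Node(List(Leaf(3), Node(List(Leaf(5), Leaf(8)))))
  ))
```

Unlike List[Any], the type now rules out inputs such as a String sitting next to an Int.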

It becomes a bit easier to fill in the blanks now.

def flatten[A](tree: Tree[A]): List[A] = {
  // Prepend each leaf onto the accumulator (O(1)), then reverse once at the end.
  def flattenRec(acc: List[A], t: Tree[A]): List[A] = t match {
    case Leaf(a)  => a :: acc
    case Node(ll) => ll.foldLeft(acc)(flattenRec)
  }
  flattenRec(Nil, tree).reverse
}
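As a quick check, here is a self-contained run of this accumulator version on the first test input from the question. The Tree encoding of that input is an illustration, not part of the answer.

```scala
sealed trait Tree[A]
case class Node[A](l: List[Tree[A]]) extends Tree[A]
case class Leaf[A](a: A) extends Tree[A]

// The answer's accumulator-based flatten: each leaf is prepended in O(1)
// and the result is reversed once, so no ::: is needed.
def flatten[A](tree: Tree[A]): List[A] = {
  def flattenRec(acc: List[A], t: Tree[A]): List[A] = t match {
    case Leaf(a)  => a :: acc
    case Node(ll) => ll.foldLeft(acc)(flattenRec)
  }
  flattenRec(Nil, tree).reverse
}

// List(List(1, 1), 2, List(3, List(5, 8))) from the question, as a Tree[Int].
val t1: Tree[Int] = Node(List(
  Node(List(Leaf(1), Leaf(1))),
  Leaf(2),
  Node(List(Leaf(3), Node(List(Leaf(5), Leaf(8)))))
))

val f1 = flatten(t1) // List(1, 1, 2, 3, 5, 8)
```

The total work is linear in the number of leaves (one cons per leaf plus one reverse), and the recursion depth grows with the nesting depth rather than with the number of elements.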

However, if we give our Tree some additional capabilities using scalaz, this becomes even easier, and may in fact help with whatever you wanted to do with the flattened list. Below is a definition of a scalaz.Foldable[Tree] instance.

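// Explicit imports: a wildcard `import scalaz._` would also bring
// scalaz.Tree into scope, which can clash with the Tree defined here.
import scalaz.{Foldable, Monoid}
import scalaz.syntax.foldable._

// Companion of our Tree trait (it must live in the same file), so the
// instance is picked up by implicit search automatically.
object Tree {
  implicit val treeFoldable: Foldable[Tree] = new Foldable[Tree] {
    override def foldMap[A, B](fa: Tree[A])(f: A => B)(implicit F: Monoid[B]): B =
      fa match {
        case Leaf(a) => f(a)
        case Node(l) => l.foldLeft(F.zero)((acc, tree) => F.append(acc, foldMap(tree)(f)))
      }

    override def foldRight[A, B](fa: Tree[A], z: => B)(f: (A, => B) => B): B = fa match {
      case Leaf(a) => f(a, z)
      case Node(l) => l.foldRight(z)((tree, zz) => foldRight[A, B](tree, zz)(f))
    }
  }
}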
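For readers without scalaz on the classpath, here is a dependency-free sketch of what foldMap computes. The Monoid trait, listMonoid, intAddMonoid and foldMapTree are hypothetical names made up for illustration; they only mirror the shape of scalaz's Monoid (zero, append), not its actual API.

```scala
sealed trait Tree[A]
case class Node[A](l: List[Tree[A]]) extends Tree[A]
case class Leaf[A](a: A) extends Tree[A]

// Hypothetical minimal monoid: an identity element and an associative combine.
trait Monoid[B] {
  def zero: B
  def append(x: B, y: B): B
}

// Illustrative instances (not scalaz's).
def listMonoid[A]: Monoid[List[A]] = new Monoid[List[A]] {
  def zero: List[A] = Nil
  def append(x: List[A], y: List[A]): List[A] = x ::: y
}
val intAddMonoid: Monoid[Int] = new Monoid[Int] {
  def zero: Int = 0
  def append(x: Int, y: Int): Int = x + y
}

// Same shape as the foldMap in the Foldable instance: map each leaf with f,
// then combine the children's results left to right with the monoid.
def foldMapTree[A, B](t: Tree[A])(f: A => B)(m: Monoid[B]): B = t match {
  case Leaf(a) => f(a)
  case Node(l) => l.foldLeft(m.zero)((acc, sub) => m.append(acc, foldMapTree(sub)(f)(m)))
}

val t: Tree[Int] = Node(List(Leaf(1), Node(List(Leaf(2), Leaf(3))), Leaf(4)))

val asList    = foldMapTree(t)(a => List(a))(listMonoid[Int]) // List(1, 2, 3, 4)
val leafCount = foldMapTree(t)(_ => 1)(intAddMonoid)          // 4
```

Because the List monoid's append is :::, this sketch reintroduces concatenation; it shows what foldMap means, not a faster flatten. The point is that one traversal serves many summaries, chosen by the monoid.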

Now our flatten becomes simply

def flatten2[A](tree: Tree[A]): List[A] = {
  Foldable[Tree].foldLeft(tree, List.empty[A])((acc, a) => a :: acc).reverse
}

or, using the Foldable syntax from the imports:

def flatten2[A](tree: Tree[A]): List[A] = {
  tree.foldLeft(List.empty[A])((acc, a) => a :: acc).reverse
}
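The reverse in flatten2 is needed because foldLeft visits leaves left to right and prepends each one, so the list comes out backwards. A right fold visits them from the right, so consing onto its accumulator keeps the original order. Below is a dependency-free sketch; foldRightTree and flattenR are hypothetical helpers that mirror the structure of the foldRight in the instance, but strictly (without scalaz's by-name parameters).

```scala
sealed trait Tree[A]
case class Node[A](l: List[Tree[A]]) extends Tree[A]
case class Leaf[A](a: A) extends Tree[A]

// Strict right fold over the leaves: children are folded right to left,
// each child folding into the accumulator built from its right siblings.
def foldRightTree[A, B](t: Tree[A], z: B)(f: (A, B) => B): B = t match {
  case Leaf(a) => f(a, z)
  case Node(l) => l.foldRight(z)((sub, acc) => foldRightTree(sub, acc)(f))
}

// Consing onto a right fold yields the leaves in left-to-right order,
// so no final reverse is required.
def flattenR[A](tree: Tree[A]): List[A] =
  foldRightTree(tree, List.empty[A])(_ :: _)

val t: Tree[Int] = Node(List(Leaf(1), Node(List(Leaf(2), Leaf(3))), Leaf(4)))
val flat = flattenR(t) // List(1, 2, 3, 4)
```

Both approaches do one cons per leaf; the left fold pays for an extra reverse, while the right fold needs no reverse.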

If we had a Tree[Int], we could sum all of its values:

val numbers: Tree[Int] = Node(List(Leaf(1), Node(List(Leaf(2), Leaf(3))), Leaf(4)))
val sum = numbers.foldLeft(0)(_ + _) // 10

As it turns out, scalaz already has a very similar Tree, which I've found incredibly useful. The difference is that scalaz.Tree stores an A in every node, not only in the leaves.
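For completeness, if the input really has to stay a List[Any], the same accumulator-and-reverse idea also removes the ::: from the question's version. This is a minimal sketch, not taken from the answer; flattenAny is a hypothetical name, and it treats any value that is a List as another level of nesting.

```scala
// Walks the nested structure, prepending every non-list element onto the
// accumulator (O(1) each) and reversing once at the end. No ::: is used.
def flattenAny(xs: List[Any]): List[Any] = {
  def go(acc: List[Any], rest: List[Any]): List[Any] =
    rest.foldLeft(acc) {
      case (a, nested: List[_]) => go(a, nested) // descend into a nested list
      case (a, x)               => x :: a        // plain value: prepend it
    }
  go(Nil, xs).reverse
}

val f1 = flattenAny(List(List(1, 1), 2, List(3, List(5, 8)))) // List(1, 1, 2, 3, 5, 8)
```

Each element is still classified by a runtime type test, which is exactly the weakness of List[Any] that motivated the Tree type.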