Vandexel - 24 days ago
Scala Question

Binary tree folding

I have been given the following definition of a binary tree

abstract class Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

and the following function

def fold_tree[A, B](f1: A => B)(f2: (A, B, B) => B)(t: Tree[A]): B =
  t match {
    case Leaf(value) => f1(value)
    // post order: both subtrees are folded before f2 combines them with the node's value
    case Node(value, l, r) => f2(value, fold_tree(f1)(f2)(l), fold_tree(f1)(f2)(r))
  }

I have been asked to return the rightmost value in the tree using the fold method. How would this work? I'm not sure I fully understand fold, but I thought the point of fold was to do some operation to every element in the tree. How can I use it to just return the rightmost value of a tree?

I am also unsure as to how to call the method. I keep getting issues with unspecified parameter type. Can someone show me the proper way to call this fold_tree method?


How would this work? I'm not sure I fully understand fold

Your fold_tree method walks the tree recursively: it calls itself on the left and right subtrees of every Node and stops at each Leaf. It has two type parameters, A and B: A is the element type of the tree you pass in, and B is the type of the result you want to compute.

Let's for the sake of the example say we have a Tree[Int] defined like this:

val n = Node(1, Leaf(2), Node(3, Leaf(5), Node(4, Leaf(42), Leaf(50))))

The tree has a structure that looks like this:

    1
   / \
  2   3
     / \
    5   4
       / \
     42   50

We want to get the rightmost value, which is 50. To do that with the current implementation of fold_tree, we need to pass it two functions:

  1. f1: A => B: Given an A, project a B value
  2. f2: (A, B, B) => B: Given one A and two B values, project a B.

Looking at the implementation, f1 is applied to each Leaf and f2 to each Node, which is why they take different numbers of arguments. The two B arguments that f2 receives are the already-folded results of the node's left and right subtrees.
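
One way to see which function runs where is to fold the tree into a string that records every call. The sketch below repeats the definitions from the question so it runs on its own; the object name FoldTrace and the value described are made up for illustration.

```scala
object FoldTrace {
  abstract class Tree[+A]
  case class Leaf[A](value: A) extends Tree[A]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def fold_tree[A, B](f1: A => B)(f2: (A, B, B) => B)(t: Tree[A]): B =
    t match {
      case Leaf(value)       => f1(value)
      case Node(value, l, r) => f2(value, fold_tree(f1)(f2)(l), fold_tree(f1)(f2)(r))
    }

  val n: Tree[Int] = Node(1, Leaf(2), Node(3, Leaf(5), Node(4, Leaf(42), Leaf(50))))

  // f1 labels each Leaf; f2 labels each Node and embeds the results
  // already computed for its left and right subtrees.
  val described: String =
    fold_tree[Int, String](a => s"f1($a)")((a, l, r) => s"f2($a, $l, $r)")(n)
}
```

Here described comes out as f2(1, f1(2), f2(3, f1(5), f2(4, f1(42), f1(50)))), which mirrors the shape of the tree: one f1 per Leaf, one f2 per Node.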

Armed with that knowledge, given our tree:

val n = Node(1, Leaf(2), Node(3, Leaf(5), Node(4, Leaf(42), Leaf(50))))

We make the following call:

println(fold_tree[Int, Int](identity)((x, y, z) => z)(n))

What this means is as follows:

  1. When we encounter a Leaf (where f1 is applied), we simply want to extract its value, so f1 is identity.
  2. When we encounter a Node (where f2 is applied), we ignore the node's own value x and the left result y, and return z, the result of folding the right subtree. Repeating this at every level follows the right branch down to the rightmost leaf.

Running this yields the expected result: 50.
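
On the "unspecified parameter type" errors from the question: as far as I understand, the compiler settles A and B one parameter list at a time, from left to right, so it needs enough information by the time it has seen f1. Spelling out the type arguments, as in the call above, is one way; annotating the parameter of f1 is another. A minimal sketch, with the definitions repeated so it runs standalone (CallingFold, explicitTypes and annotatedLambda are illustrative names):

```scala
object CallingFold {
  abstract class Tree[+A]
  case class Leaf[A](value: A) extends Tree[A]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def fold_tree[A, B](f1: A => B)(f2: (A, B, B) => B)(t: Tree[A]): B =
    t match {
      case Leaf(value)       => f1(value)
      case Node(value, l, r) => f2(value, fold_tree(f1)(f2)(l), fold_tree(f1)(f2)(r))
    }

  val n = Node(1, Leaf(2), Node(3, Leaf(5), Node(4, Leaf(42), Leaf(50))))

  // Option 1: give both type arguments explicitly.
  val explicitTypes: Int = fold_tree[Int, Int](identity)((x, y, z) => z)(n)

  // Option 2: annotate f1's parameter; A and B are then known
  // before the compiler looks at f2, so f2 needs no annotations.
  val annotatedLambda: Int = fold_tree((a: Int) => a)((_, _, z) => z)(n)
}
```

Both calls compute the same fold; which one reads better is mostly a matter of taste.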

If we want to generalize this to any element type A, we can write:

def extractRightmostValue[A](tree: Tree[A]): A =
  fold_tree[A, A](identity)((x, y, z) => z)(tree)
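
As for the worry that a fold has to do something with every element: the fold does visit every node, but f2 is free to discard what it is given. The sketch below (definitions repeated so it is self-contained; RightmostDemo, sum and the sample trees are illustrative) uses the same fold once to keep only the right-hand result and once to combine everything.

```scala
object RightmostDemo {
  abstract class Tree[+A]
  case class Leaf[A](value: A) extends Tree[A]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def fold_tree[A, B](f1: A => B)(f2: (A, B, B) => B)(t: Tree[A]): B =
    t match {
      case Leaf(value)       => f1(value)
      case Node(value, l, r) => f2(value, fold_tree(f1)(f2)(l), fold_tree(f1)(f2)(r))
    }

  // Keeps only the result coming from the right subtree at every Node.
  def extractRightmostValue[A](tree: Tree[A]): A =
    fold_tree[A, A](identity)((_, _, right) => right)(tree)

  // Same traversal, but this f2 uses all three of its arguments.
  def sum(tree: Tree[Int]): Int =
    fold_tree[Int, Int](identity)((v, l, r) => v + l + r)(tree)

  val ints: Tree[Int] = Node(1, Leaf(2), Node(3, Leaf(5), Node(4, Leaf(42), Leaf(50))))
  val words: Tree[String] = Node("a", Node("b", Leaf("c"), Leaf("d")), Leaf("e"))

  val rightmostInt: Int     = extractRightmostValue(ints)   // 50
  val rightmostWord: String = extractRightmostValue(words)  // "e"
  val total: Int            = sum(ints)                     // 107
}
```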