Vandexel - 1 year ago 100
Scala Question

Binary tree folding

I have been given the following definition of a binary tree

```scala
abstract class Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]
```

and the following function

```scala
def fold_tree[A, B](f1: A => B)(f2: (A, B, B) => B)(t: Tree[A]): B =
  t match {
    case Leaf(value)       => f1(value)
    case Node(value, l, r) => f2(value, fold_tree(f1)(f2)(l), fold_tree(f1)(f2)(r)) // post-order
  }
```

I have been asked to return the rightmost value in the tree using the fold method. How would this work? I'm not sure I fully understand fold, but I thought the point of fold was to do some operation to every element in the tree. How can I use it to just return the rightmost value of a tree?

I am also unsure how to call the method. I keep getting errors about unspecified parameter types. Can someone show me the proper way to call this `fold_tree` method?

> How would this work? I'm not sure I fully understand fold

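Your `fold_tree` method (written `foldTree` in this answer) recursively walks the tree. At each `Leaf` it applies `f1`; at each `Node` it first folds the left and right subtrees and then passes the node's value, together with those two results, to `f2`. So a fold does not just run an operation on every element: it combines the results bottom-up into a single value of type `B`. The method also has two type parameters that must be determined when you call it: `A`, the element type of the tree, and `B`, the type of the result.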
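To see that a fold collapses the whole tree into one value rather than merely visiting each element, here is a minimal sketch that reuses the question's `Tree` classes and fold (named `foldTree`, as in this answer). The names `FoldExamples`, `sumTree` and `countElements` are hypothetical, chosen only for illustration.

```scala
object FoldExamples {
  // Tree definition and fold copied from the question (fold renamed to foldTree).
  abstract class Tree[+A]
  case class Leaf[A](value: A) extends Tree[A]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def foldTree[A, B](f1: A => B)(f2: (A, B, B) => B)(t: Tree[A]): B =
    t match {
      case Leaf(value)       => f1(value)
      case Node(value, l, r) => f2(value, foldTree(f1)(f2)(l), foldTree(f1)(f2)(r))
    }

  // Sum of all values: a leaf contributes its value; a node adds its own value
  // to the already-folded sums of its left and right subtrees.
  def sumTree(t: Tree[Int]): Int =
    foldTree[Int, Int](v => v)((v, l, r) => v + l + r)(t)

  // Number of elements: each leaf counts 1, each node counts 1 plus its subtrees.
  def countElements[A](t: Tree[A]): Int =
    foldTree[A, Int](_ => 1)((_, l, r) => 1 + l + r)(t)

  val n: Tree[Int] = Node(1, Leaf(2), Node(3, Leaf(5), Node(4, Leaf(42), Leaf(50))))

  def main(args: Array[String]): Unit = {
    println(sumTree(n))       // 107
    println(countElements(n)) // 7
  }
}
```

Only `f1` and `f2` change between the two helpers; the traversal itself stays in `foldTree`.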

For the sake of example, let's say we have a `Tree[Int]` defined like this:

```scala
val n = Node(1, Leaf(2), Node(3, Leaf(5), Node(4, Leaf(42), Leaf(50))))
```

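The tree has the following structure: the root node with value `1` has `Leaf(2)` as its left child and the node with value `3` as its right child; that node has `Leaf(5)` on the left and the node with value `4` on the right, whose children are `Leaf(42)` and `Leaf(50)`.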

We want to get the rightmost value, which is `50`. To do that with the current implementation of `foldTree`, we need to provide two functions:

1. `f1: A => B`: Given an `A`, project a `B` value
2. `f2: (A, B, B) => B`: Given one `A` and two `B` values, project a `B`.

We can see that `f1` is applied to a `Leaf` and `f2` is applied to a `Node` (hence the different number of arguments each function takes). Note that the two `B` arguments of `f2` are not the raw subtrees but the results of already folding the left and right subtrees.
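One way to see which function runs where is to fold into a `String` that spells out each call. This is a sketch using the question's definitions (fold named `foldTree`); the `FoldTrace` object and `render` helper are hypothetical.

```scala
object FoldTrace {
  // Tree definition and fold copied from the question (fold renamed to foldTree).
  abstract class Tree[+A]
  case class Leaf[A](value: A) extends Tree[A]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def foldTree[A, B](f1: A => B)(f2: (A, B, B) => B)(t: Tree[A]): B =
    t match {
      case Leaf(value)       => f1(value)
      case Node(value, l, r) => f2(value, foldTree(f1)(f2)(l), foldTree(f1)(f2)(r))
    }

  // f1 only ever sees leaf values; f2 sees a node's value plus the strings
  // already produced for its left and right subtrees.
  def render(t: Tree[Int]): String =
    foldTree[Int, String](v => s"f1($v)")((v, l, r) => s"f2($v, $l, $r)")(t)

  val n: Tree[Int] = Node(1, Leaf(2), Node(3, Leaf(5), Node(4, Leaf(42), Leaf(50))))

  def main(args: Array[String]): Unit =
    println(render(n)) // f2(1, f1(2), f2(3, f1(5), f2(4, f1(42), f1(50))))
}
```

The nested output mirrors the recursion: every `f2` call receives results that were already computed for its subtrees.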

Armed with that knowledge, given our tree:

```scala
val n = Node(1, Leaf(2), Node(3, Leaf(5), Node(4, Leaf(42), Leaf(50))))
```

We call `foldTree` with the following functions:

```scala
println(foldTree[Int, Int](identity)((x, y, z) => z)(n))
```

What this means is as follows:

1. When we encounter a `Leaf` (and apply `f1` to it), we simply want to extract its value, so `f1` is `identity`.
2. When we encounter a `Node` (and apply `f2` to it), we want the rightmost value of its right subtree, which is the third argument, `z`. The node's own value `x` and the left result `y` are ignored.

Running this yields the expected result: `50`.
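On the "unspecified parameter type" error from the question: at least in Scala 2, type arguments are inferred one parameter list at a time, from left to right. Because the tree is the last argument, `A` and `B` are still unknown when the lambdas in the first two lists are type-checked, so their parameter types cannot be inferred. Below is a minimal sketch of two calls that do compile; the `CallingFold` object is a hypothetical wrapper around the question's definitions.

```scala
object CallingFold {
  // Tree definition and fold copied from the question (fold renamed to foldTree).
  abstract class Tree[+A]
  case class Leaf[A](value: A) extends Tree[A]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def foldTree[A, B](f1: A => B)(f2: (A, B, B) => B)(t: Tree[A]): B =
    t match {
      case Leaf(value)       => f1(value)
      case Node(value, l, r) => f2(value, foldTree(f1)(f2)(l), foldTree(f1)(f2)(r))
    }

  val n: Tree[Int] = Node(1, Leaf(2), Node(3, Leaf(5), Node(4, Leaf(42), Leaf(50))))

  // Option 1: give A and B explicitly, so every lambda's parameter types are known.
  val r1: Int = foldTree[Int, Int](x => x)((_, _, right) => right)(n)

  // Option 2: annotate the first function; A = Int and B = Int are fixed from the
  // first parameter list and reused when the second list is type-checked.
  val r2: Int = foldTree((x: Int) => x)((_, _, right) => right)(n)

  // Fails in Scala 2 with "missing parameter type": A and B are still unknown here.
  // val bad = foldTree(x => x)((_, _, right) => right)(n)

  def main(args: Array[String]): Unit = {
    println(r1) // 50
    println(r2) // 50
  }
}
```

Alternatively, putting the tree in the first parameter list of the fold would let the compiler learn `A` from it before it types the functions.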

To make this generic for any `A`, we can write:

```scala
def extractRightmostValue[A](tree: Tree[A]): A =
  foldTree[A, A](identity)((x, y, z) => z)(tree)
```
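A quick usage sketch of `extractRightmostValue` on trees with different element types, assuming the question's definitions (fold named `foldTree`). Only the explicit return type `A` is added; the `RightmostDemo` object and the demo trees are made up for illustration.

```scala
object RightmostDemo {
  // Tree definition and fold copied from the question (fold renamed to foldTree).
  abstract class Tree[+A]
  case class Leaf[A](value: A) extends Tree[A]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def foldTree[A, B](f1: A => B)(f2: (A, B, B) => B)(t: Tree[A]): B =
    t match {
      case Leaf(value)       => f1(value)
      case Node(value, l, r) => f2(value, foldTree(f1)(f2)(l), foldTree(f1)(f2)(r))
    }

  // Leaves return their own value; nodes pass up the result of their right subtree.
  def extractRightmostValue[A](tree: Tree[A]): A =
    foldTree[A, A](identity)((_, _, right) => right)(tree)

  val ints: Tree[Int] = Node(1, Leaf(2), Node(3, Leaf(5), Node(4, Leaf(42), Leaf(50))))
  val words: Tree[String] = Node("root", Leaf("left"), Node("mid", Leaf("a"), Leaf("b")))
  val single: Tree[Double] = Leaf(7.5)

  def main(args: Array[String]): Unit = {
    println(extractRightmostValue(ints))   // 50
    println(extractRightmostValue(words))  // b
    println(extractRightmostValue(single)) // 7.5
  }
}
```

Because every `Node` in this definition has exactly two children, the rightmost value always comes from a `Leaf`; a node's own value is never returned.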