Dmitrii Dmitrii - 2 months ago 14
Scala Question

Counting layers in tree like structure

I've got a tree-like structure using Map:

val m = Map[Int, (Set[Int], Set[Int])]()


where the map key is the node id and the two sets hold the node's parents and children, respectively. I'm trying to recursively count the number of layers above and below a node. For instance, given the tree 0 - 1 - 2 - (3, 4), I expect some function to return a List of Sets in which each Set is one layer of the tree. I've got the following method, which gathers all parents:

def p(n:Set[Int]):Set[Int] = if(n.isEmpty) Set.empty else n ++ m(n.head)._1 ++ p(n.tail)


but I'd like the parents grouped by the level of the tree they belong to, so that I could get the desired result by calling size on it.
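A minimal sketch of the grouping, assuming the same Map[Int, (Set[Int], Set[Int])] layout: instead of unioning everything into one Set, each recursive step maps the whole current layer to the union of its parents and prepends that as a new layer. `parentLayers` is a hypothetical name, and the data is the simplified tree 0 - 1 - 2 - (3, 4) from the question.

```scala
// Node id -> (parents, children): the simplified tree 0 - 1 - 2 - (3, 4)
val m: Map[Int, (Set[Int], Set[Int])] = Map(
  0 -> (Set.empty[Int], Set(1)),
  1 -> (Set(0), Set(2)),
  2 -> (Set(1), Set(3, 4)),
  3 -> (Set(2), Set.empty[Int]),
  4 -> (Set(2), Set.empty[Int])
)

// Hypothetical helper: the parents of the whole layer `n`, then their parents,
// and so on, one Set per level, nearest level first.
// Assumes every id is a key of `m` and the structure has no cycles.
def parentLayers(n: Set[Int]): List[Set[Int]] = {
  val next = n.flatMap(id => m(id)._1)
  if (next.isEmpty) Nil else next :: parentLayers(next)
}

println(parentLayers(Set(3)))      // List(Set(2), Set(1), Set(0))
println(parentLayers(Set(3)).size) // 3
```

Stopping at the first empty layer means size counts only real layers. The recursion is not tail recursive, which is fine for trees of modest depth.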

UPD:

m = Map(0 -> (Set(), Set(1)), 1 -> (Set(0), Set(2, 3)), 2 -> (Set(1), Set(4, 5)), 3 -> (Set(2), Set(6, 7)), ....)


This is what my map m could look like after it is filled with tree nodes, and from it I want to get another Map that could look like:

Map(0 -> (List(Set()), List(Set(1), Set(2, 3), Set(4, 5, 6, 7))), 1 -> (List(Set(), Set(0)), List(Set(2, 3), Set(4, 5, 6, 7))), ...)


That is, I want the result grouped by level: every parent layer in its own Set and every child layer in its own Set.

Below is a simplified example:
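One way to build the requested Map for every node, sketched under the assumption that the structure is acyclic: a single hypothetical helper `layers` walks in either direction, given a function from a node id to its neighbours in that direction (parents or children). The test data is the simplified example map from the question.

```scala
// Node id -> (parents, children): the simplified example from the question
val m: Map[Int, (Set[Int], Set[Int])] = Map(
  2 -> (Set(1), Set(3, 4)),
  4 -> (Set(2), Set.empty[Int]),
  1 -> (Set(0), Set(2)),
  3 -> (Set(2), Set.empty[Int]),
  0 -> (Set.empty[Int], Set(1))
)

def parents(id: Int): Set[Int] = m.get(id).fold(Set.empty[Int])(_._1)
def children(id: Int): Set[Int] = m.get(id).fold(Set.empty[Int])(_._2)

// Hypothetical helper: layers reached from `start` by applying `next` to a whole
// layer at a time, nearest first, stopping at the first empty layer (assumes no cycles).
def layers(start: Int, next: Int => Set[Int]): List[Set[Int]] =
  Iterator
    .iterate(next(start))(layer => layer.flatMap(next))
    .takeWhile(_.nonEmpty)
    .toList

// node id -> (parent layers, child layers)
val layered: Map[Int, (List[Set[Int]], List[Set[Int]])] =
  m.keys.map(id => id -> (layers(id, parents), layers(id, children))).toMap

println(layered(3)) // (List(Set(2), Set(1), Set(0)),List())
```

Whether each list should end with an empty Set, as in the examples in the question, is a presentation choice; this sketch stops at the first empty layer so that size equals the number of layers.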

val m = Map(2 -> (Set(1),Set(3, 4)), 4 -> (Set(2),Set()), 1 -> (Set(0),Set(2)), 3 -> (Set(2),Set()), 0 -> (Set(),Set(1)))


Here the tree has the structure 0 - 1 - 2 - (3, 4).

So 0 is the root; it has a child 1, which in turn has a child 2, which has two children, 3 and 4. In a more complex case a node could have multiple parents, but they are all unique, which is why I chose Set (though it could be any other collection). With Set I can easily gather all parent nodes upward and all child nodes downward; the only thing I still need is to group them by the level on which they reside. In this case node 3 should have List(Set(2), Set(1), Set(0), Set()) as its parents.
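When a node can have several parents, the same ancestor may be reachable along paths of different lengths and would then show up in more than one layer. If each ancestor should be counted once, at its nearest layer (an assumption about the desired semantics), the nodes already placed can be subtracted at each step. The DAG and the names `parentsOf` and `parentLayersDedup` are hypothetical.

```scala
import scala.annotation.tailrec

// Hypothetical DAG (node id -> (parents, children)):
// node 2 has two parents, 0 and 1, and 1 is itself a child of 0
val dag: Map[Int, (Set[Int], Set[Int])] = Map(
  0 -> (Set.empty[Int], Set(1, 2)),
  1 -> (Set(0), Set(2)),
  2 -> (Set(0, 1), Set.empty[Int])
)

def parentsOf(id: Int): Set[Int] = dag.get(id).fold(Set.empty[Int])(_._1)

// Parent layers of `start`, nearest first. Nodes already placed in a layer
// are subtracted, so an ancestor appears only in its nearest layer.
def parentLayersDedup(start: Int): List[Set[Int]] = {
  @tailrec
  def loop(layer: Set[Int], seen: Set[Int], acc: List[Set[Int]]): List[Set[Int]] = {
    val next = layer.flatMap(id => parentsOf(id)) -- seen
    if (next.isEmpty) acc.reverse
    else loop(next, seen ++ next, next :: acc)
  }
  loop(Set(start), Set(start), Nil)
}

// Without `-- seen` the result would be List(Set(0, 1), Set(0))
println(parentLayersDedup(2)) // List(Set(0, 1))
```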

Answer

BFS-style traversal

Do a BFS-style traversal and keep adding each node to the map under its level.

BFS keeps a queue (a List serves as the queue in this code) and visits the tree or graph level by level, which is exactly what we need.

One important point is how to detect the end of a level. I mark it with a sentinel value, EndOfLevel.

When you dequeue EndOfLevel, append another EndOfLevel if elements are left in the queue (they all belong to the next level); otherwise we are done and return the result.

import scala.annotation.tailrec

sealed trait Node

case class ANode(value: Int) extends Node

// Sentinel that marks the end of one level in the queue
case object EndOfLevel extends Node

def bfs(root: Node, map: Map[Node, (Set[Node], Set[Node])]): List[(Int, Set[Node])] = {

  @tailrec
  def helper(queue: List[Node], level: Int, result: Map[Int, Set[Node]]): Map[Int, Set[Node]] = {

    if (queue.nonEmpty) {

      queue.head match {
        case anode @ ANode(_) =>
          // Neighbours go to the back of the queue; they belong to the next level
          val newQueue = queue.tail ++ getNodes(anode, map)

          // Record this node under the current level
          val newResult: Map[Int, Set[Node]] =
            result + (level -> (result.getOrElse(level, Set.empty[Node]) + anode))

          helper(newQueue, level, newResult)

        case EndOfLevel =>
          // Everything still in the queue belongs to the next level
          if (queue.tail.nonEmpty) helper(queue.tail :+ EndOfLevel, level + 1, result) else result
      }
    } else result
  }

  helper(List(root, EndOfLevel), 0, Map(0 -> Set.empty[Node])).toList
}

// Returns both sets stored for a node. If the map holds (parents, children),
// parents get re-enqueued and the traversal never ends; follow only the
// children (or track visited nodes) in that case.
def getNodes(node: Node, map: Map[Node, (Set[Node], Set[Node])]): Set[Node] = {
  val (left, right) = map.getOrElse(node, (Set.empty[Node], Set.empty[Node]))
  left ++ right
}
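A different technique from the EndOfLevel sentinel, sketched here for comparison: store each node's depth next to it in the queue, so no marker is needed to detect where a level ends. This sketch assumes the question's Int-keyed (parents, children) map, follows only the children (so parents are never re-enqueued), and uses scala.collection.immutable.Queue; `levelsBelow` is a hypothetical name.

```scala
import scala.annotation.tailrec
import scala.collection.immutable.Queue

// Node id -> (parents, children): the simplified tree 0 - 1 - 2 - (3, 4) from the question
val m: Map[Int, (Set[Int], Set[Int])] = Map(
  0 -> (Set.empty[Int], Set(1)),
  1 -> (Set(0), Set(2)),
  2 -> (Set(1), Set(3, 4)),
  3 -> (Set(2), Set.empty[Int]),
  4 -> (Set(2), Set.empty[Int])
)

// Level-order traversal below `root` (depth 0 is the root itself).
// Each queue entry carries its depth, so no end-of-level marker is needed.
def levelsBelow(root: Int): Map[Int, Set[Int]] = {
  @tailrec
  def loop(queue: Queue[(Int, Int)], acc: Map[Int, Set[Int]]): Map[Int, Set[Int]] =
    queue.dequeueOption match {
      case None => acc
      case Some(((id, depth), rest)) =>
        val children = m.get(id).fold(Set.empty[Int])(_._2)
        val updated = acc + (depth -> (acc.getOrElse(depth, Set.empty[Int]) + id))
        // Children are one level deeper than the node that enqueued them
        loop(rest ++ children.map(c => (c, depth + 1)), updated)
    }

  loop(Queue((root, 0)), Map.empty)
}

println(levelsBelow(0).toList.sortBy(_._1))
```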

Note that you can make the code faster by using Vector instead of List for the queue: appending to a List with ++ copies the whole list, whereas appending to a Vector is effectively constant time.
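A sketch of the same sentinel-based traversal with a Vector as the queue. It keeps the answer's Node, ANode and EndOfLevel definitions; matching on headOption replaces the nonEmpty/head pair, it returns the level Map directly instead of a List, and `neighbours` (a local stand-in for getNodes) still follows both sets, so it assumes a map without parent back-links, like the one in the running example.

```scala
import scala.annotation.tailrec

sealed trait Node
case class ANode(value: Int) extends Node
case object EndOfLevel extends Node

// Sentinel-based BFS with a Vector queue. Returns level -> nodes on that level.
def bfs(root: Node, map: Map[Node, (Set[Node], Set[Node])]): Map[Int, Set[Node]] = {
  // Both sets stored for a node (same behaviour as getNodes in the answer)
  def neighbours(node: Node): Set[Node] = {
    val (left, right) = map.getOrElse(node, (Set.empty[Node], Set.empty[Node]))
    left ++ right
  }

  @tailrec
  def helper(queue: Vector[Node], level: Int, result: Map[Int, Set[Node]]): Map[Int, Set[Node]] =
    queue.headOption match {
      case None => result
      case Some(anode @ ANode(_)) =>
        // Appending to a Vector is effectively constant time
        val newQueue = queue.tail ++ neighbours(anode)
        val newResult = result + (level -> (result.getOrElse(level, Set.empty[Node]) + anode))
        helper(newQueue, level, newResult)
      case Some(EndOfLevel) =>
        // Remaining elements form the next level
        if (queue.tail.nonEmpty) helper(queue.tail :+ EndOfLevel, level + 1, result)
        else result
    }

  helper(Vector(root, EndOfLevel), 0, Map(0 -> Set.empty[Node]))
}

// The map from the running example: each node has a "left" and a "right" child
val map: Map[Node, (Set[Node], Set[Node])] = Map(
  ANode(1) -> (Set[Node](ANode(2)), Set[Node](ANode(3))),
  ANode(2) -> (Set[Node](ANode(4)), Set[Node](ANode(5))),
  ANode(3) -> (Set[Node](ANode(6)), Set[Node](ANode(7)))
)

println(bfs(ANode(1), map).toList.sortBy(_._1))
```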

Running Code

sealed trait Node

case class ANode(value: Int) extends Node

case object EndOfLevel extends Node

object Main {

  def bfs(root: Node, map: Map[Node, (Set[Node], Set[Node])]): List[(Int, Set[Node])] = {
    def helper(queue: List[Node], level: Int, result: Map[Int, Set[Node]]): Map[Int, Set[Node]] = {
      if (queue.nonEmpty) {
        queue.head match {
          case anode@ANode(_) =>
            val newQueue = queue.tail ++ getNodes(anode, map)
            val newResult: Map[Int, Set[Node]] =
              if (result contains level) {
                result + (level -> (Set(anode) ++ result(level)))
              } else {
                result + (level -> Set(anode))
              }
            helper(newQueue, level, newResult)
          case EndOfLevel =>
            if (queue.tail.nonEmpty) helper(queue.tail ++ List(EndOfLevel), level + 1, result) else result
        }
      } else result
    }
    helper(List(root) ++ List(EndOfLevel), 0, Map(0 -> Set.empty[Node])).toList
  }


  def main(args: Array[String]): Unit = {
    val map: Map[Node, (Set[Node], Set[Node])] = Map(
      ANode(1) -> (Set[Node](ANode(2)) -> Set[Node](ANode(3))),
      ANode(2) -> (Set[Node](ANode(4)) -> Set[Node](ANode(5))),
      ANode(3) -> (Set[Node](ANode(6)) -> Set[Node](ANode(7)))
    )
    println(bfs(ANode(1), map))

  }


  def getNodes(node: Node, map: Map[Node, (Set[Node], Set[Node])]): Set[Node] = {
    val (left, right) = map.getOrElse(node, (Set.empty[Node], Set.empty[Node]))
    left ++ right
  }
}

Output

List((0,Set(ANode(1))), (1,Set(ANode(3), ANode(2))), (2,Set(ANode(7), ANode(6), ANode(5), ANode(4))))