
Checking your tree balance in Scala

Scala is known for having multiple ways of solving the same problem. I would argue that this feature is not really Scala-specific, but let’s entertain this thought for a moment and deepen the stereotype even further.


Is your tree even balanced? 🌳

To explore this, we will use the famous binary tree balancing problem, LeetCode 110 (Balanced Binary Tree):

Given a binary tree, determine whether it is height balanced, i.e. the depth of the two subtrees of each node never differs by more than one.

First, let’s define the input:

final case class TreeNode(
    value: Int = 0, 
    left: TreeNode | Null = null, 
    right: TreeNode | Null = null
)
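To make the definition concrete, here are two small inputs built with the class above, one height-balanced and one not. The tree shapes and the `Examples` object are purely illustrative. Named arguments make it clear which side each child hangs on, and the `null` defaults stand in for missing children.

```scala
final case class TreeNode(
    value: Int = 0,
    left: TreeNode | Null = null,
    right: TreeNode | Null = null
)

object Examples:
  //     1
  //    / \
  //   2   3      at every node, the two subtree heights differ by at most 1
  //  /
  // 4
  val balanced: TreeNode =
    TreeNode(1, left = TreeNode(2, left = TreeNode(4)), right = TreeNode(3))

  //     1
  //    /
  //   2          the root's left subtree has height 2, its right one is empty
  //  /
  // 3
  val unbalanced: TreeNode =
    TreeNode(1, left = TreeNode(2, left = TreeNode(3)))
```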

Yes, I used nulls in Scala. This is fine because I’m assuming Scala 3’s -Yexplicit-nulls compiler option, which makes nullability part of the type (TreeNode | Null). Without it, it’s better to stick to the good old Option.

Now, we would like something that implements the following signature:

def isBalanced(root: TreeNode): Boolean = ???

Approach #1

When I hear both "tree" and "height", the first thing that comes to mind is depth-first search and some recursion. Let’s have a go at it:

object Approach1:
  def isBalanced(root: TreeNode): Boolean =
    def go(node: TreeNode | Null): (Boolean, Int) =
      if (node == null)
      then (true, 0)
      else
        val (leftBalanced, leftHeight) = go(node.left)
        val (rightBalanced, rightHeight) = go(node.right)
        val balanced = leftBalanced && rightBalanced && math.abs(leftHeight - rightHeight) <= 1
        val height = math.max(leftHeight, rightHeight) + 1
        (balanced, height)

    go(root)._1

I think that’s a pretty neat solution - we traverse the tree only once and collect heights along the way. The time complexity should be O(n) and the space complexity O(h), where n is the number of nodes and h is the height of the tree.

There’s just one problem with this solution - it’s not stack-safe!

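With deep enough trees, for example a degenerate one where every node has only a left child, each level adds a frame to the call stack, and you will eventually hit a StackOverflowError.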
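A quick way to see this is to feed Approach1 (copied below so the block compiles on its own) a chain a million levels deep. `leftChain` and `overflows` are hypothetical helpers written for this demo. The chain is built with a loop, so constructing the input is not the thing that overflows. The depth at which `go` actually fails depends on the JVM thread stack size (`-Xss`).

```scala
final case class TreeNode(
    value: Int = 0,
    left: TreeNode | Null = null,
    right: TreeNode | Null = null
)

object Approach1:
  def isBalanced(root: TreeNode): Boolean =
    def go(node: TreeNode | Null): (Boolean, Int) =
      if (node == null)
      then (true, 0)
      else
        val (leftBalanced, leftHeight) = go(node.left)
        val (rightBalanced, rightHeight) = go(node.right)
        val balanced = leftBalanced && rightBalanced && math.abs(leftHeight - rightHeight) <= 1
        val height = math.max(leftHeight, rightHeight) + 1
        (balanced, height)

    go(root)._1

object OverflowDemo:
  // Hypothetical helper: a chain of `depth` nodes where each node is the
  // left child of the next one. Built with a loop, so building it is stack-safe.
  def leftChain(depth: Int): TreeNode =
    var node = TreeNode(0)
    for i <- 1 until depth do node = TreeNode(i, left = node)
    node

  // True if evaluating `check` blows the call stack.
  def overflows(check: => Boolean): Boolean =
    try
      check
      false
    catch case _: StackOverflowError => true
```

On a default JVM configuration, `OverflowDemo.overflows(Approach1.isBalanced(OverflowDemo.leftChain(1_000_000)))` should come back true, while small trees still get the right answer. Catching StackOverflowError here only makes the failure visible. It is not something to rely on in real code.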

Approach #2

Ok, let’s try something different. If we want to avoid overflowing the stack, we can just trampoline the whole thing.

This technique essentially moves the problem from the stack to the heap. Instead of doing the computation in place, we suspend it in a continuation and execute it later. In Scala we have a ready-made helper just for that in the form of scala.util.control.TailCalls. Let’s use it:
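To see the mechanism in isolation before applying it to the tree, here is a sketch on a smaller problem: computing a list’s length with a recursion that is not tail-recursive. `tailcall` wraps the recursive call without running it, `map(_ + 1)` stores the remaining work as a continuation object on the heap, and `.result` runs the loop that drives everything. The `TrampolineDemo` object is illustrative only.

```scala
import scala.util.control.TailCalls.*

object TrampolineDemo:
  // Not tail-recursive: the `+ 1` happens after the recursive call returns,
  // so every element costs one stack frame.
  def naiveLength(xs: List[Int]): Int = xs match
    case Nil       => 0
    case _ :: tail => naiveLength(tail) + 1

  // Trampolined version: the same recursion, but the pending `+ 1` is kept
  // as a continuation on the heap instead of as a stack frame.
  def length(xs: List[Int]): TailRec[Int] = xs match
    case Nil       => done(0)
    case _ :: tail => tailcall(length(tail)).map(_ + 1)
```

`TrampolineDemo.length(List.fill(1_000_000)(0)).result` should return 1000000 without a deep call stack, whereas `naiveLength` on the same list would most likely overflow on a default JVM.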

object Approach2:
  import scala.util.control.TailCalls.*

  def isBalanced(root: TreeNode): Boolean =
    // Some(height) if the subtree is balanced, None as soon as an imbalance is found.
    def go(node: TreeNode | Null): TailRec[Option[Int]] =
      if (node == null)
      then done(Some(0))
      else if (node.left == null && node.right == null)
      then done(Some(1))
      else
        tailcall(go(node.left)).flatMap:
          case None => done(None)
          case Some(leftHeight) =>
            tailcall(go(node.right)).map:
              case Some(rightHeight) if math.abs(leftHeight - rightHeight) <= 1 =>
                Some(math.max(leftHeight, rightHeight) + 1)
              case _ => None

    go(root).result.isDefined

I like this approach too. We added a few flatMaps, but the general flow is the same. As you might guess, this adds some runtime overhead - we are instantiating new objects left and right. In most situations this is probably fine, but if you need more performance, you have to look elsewhere.
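As a sanity check of the stack-safety claim, the sketch below runs Approach2 on a degenerate chain (each node has only a left child) a million levels deep. Approach2 is copied so the block compiles on its own, and it assumes Scala 3.3+ for the colon lambda syntax. `leftChain` is a hypothetical helper built with a loop, so constructing the input does not overflow either.

```scala
import scala.util.control.TailCalls.*

final case class TreeNode(
    value: Int = 0,
    left: TreeNode | Null = null,
    right: TreeNode | Null = null
)

object Approach2:
  def isBalanced(root: TreeNode): Boolean =
    // Some(height) if the subtree is balanced, None as soon as an imbalance is found.
    def go(node: TreeNode | Null): TailRec[Option[Int]] =
      if (node == null)
      then done(Some(0))
      else if (node.left == null && node.right == null)
      then done(Some(1))
      else
        tailcall(go(node.left)).flatMap:
          case None => done(None)
          case Some(leftHeight) =>
            tailcall(go(node.right)).map:
              case Some(rightHeight) if math.abs(leftHeight - rightHeight) <= 1 =>
                Some(math.max(leftHeight, rightHeight) + 1)
              case _ => None

    go(root).result.isDefined

object DeepCheck:
  // Hypothetical helper: a loop-built chain where each node is the left child of the next.
  def leftChain(depth: Int): TreeNode =
    var node = TreeNode(0)
    for i <- 1 until depth do node = TreeNode(i, left = node)
    node
```

The chain is unbalanced near its bottom, so `Approach2.isBalanced(DeepCheck.leftChain(1_000_000))` should be false. The interesting part is that it returns at all: the million pending continuations live on the heap rather than on the call stack.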

Approach #3

Ok, back to basics. Let’s do this in the crudest way possible by introducing some mutable state and imperative loops 🫣.

object Approach3:
  import scala.collection.mutable.{Stack, Map}

  def isBalanced(root: TreeNode): Boolean =
    type Depth = Int
    type Visited = Boolean
    val stack = Stack[(TreeNode, Visited)]((root, false))
    val depth = Map[TreeNode | Null, Depth]()

    var result = true
    var isFinished = false

    while !isFinished do
      if stack.isEmpty then isFinished = true
      else
        val (node, visited) = stack.pop()

        if visited then
          // A missing (null) child is never in the map, so it counts as -1 and a leaf gets 0.
          val left = depth.getOrElse(node.left, -1)
          val right = depth.getOrElse(node.right, -1)
          if math.abs(left - right) > 1 then
            result = false
            isFinished = true
          else depth += node -> (math.max(left, right) + 1)
        else
          stack.push((node, true))
          if node.left != null then stack.push((node.left.nn, false))
          if node.right != null then stack.push((node.right.nn, false))

    result

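Here, instead of relying on the implicit call stack, we make use of an explicit (mutable) Stack. In this particular implementation, we visit each node twice: the first time to push its children onto the Stack, the second time to read their heights cached in the depth Map and compare them. This could also be done in two separate (not nested) loops. This gives us O(2n) complexity, so… still O(n), at least if Map operations are constant-time. Since TreeNode is a case class, its hashCode is structural and recomputed over the whole subtree on each lookup, so strictly speaking an identity-based map (or a plain class) is needed to get there.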
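Here is one possible shape for the two-loop version. It is a sketch, not the post’s code. The first loop records nodes in node-right-left order, which read backwards is post-order. The second loop walks that order backwards. As a deliberate swap, it keeps finished subtree heights on a second stack instead of in a depth Map, which also avoids hashing TreeNode altogether. The `TwoLoops` name is made up for this example.

```scala
import scala.collection.mutable

final case class TreeNode(
    value: Int = 0,
    left: TreeNode | Null = null,
    right: TreeNode | Null = null
)

object TwoLoops:
  def isBalanced(root: TreeNode): Boolean =
    // Loop 1: visit nodes in node-right-left order. Read backwards, this is
    // post-order (left subtree, right subtree, node), so children come first.
    val toVisit = mutable.Stack[TreeNode](root)
    val order = mutable.ArrayBuffer.empty[TreeNode]
    while toVisit.nonEmpty do
      val node = toVisit.pop()
      order += node
      if node.left != null then toVisit.push(node.left.nn)
      if node.right != null then toVisit.push(node.right.nn) // popped first

    // Loop 2: walk `order` backwards, keeping the heights of finished subtrees
    // on a stack. Same convention as Approach3: missing child = -1, leaf = 0.
    val heights = mutable.Stack.empty[Int]
    var balanced = true
    var i = order.length - 1
    while balanced && i >= 0 do
      val node = order(i)
      // The right subtree was finished last, so its height is on top.
      val right = if node.right != null then heights.pop() else -1
      val left = if node.left != null then heights.pop() else -1
      if math.abs(left - right) > 1 then balanced = false
      else heights.push(math.max(left, right) + 1)
      i -= 1
    balanced
```

Each node is pushed and popped a constant number of times in each loop, so this stays linear without relying on how TreeNode hashes.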

To break out of the loop we use the isFinished var. Scala does have return, but it’s frowned upon, and non-local returns (returning from inside a lambda) are deprecated in Scala 3. Scala 3.3 brings a new boundary / break mechanism that we could have used here instead, but I’ll leave it as an exercise for the reader.
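For readers who want to check their answer to the boundary / break exercise, this is one possible version of Approach3, assuming Scala 3.3 or later where scala.util.boundary is available. `break(false)` exits the enclosing `boundary` block with the value false, so both `result` and `isFinished` disappear. The `Approach3Boundary` name is just for illustration.

```scala
import scala.collection.mutable.{Map, Stack}
import scala.util.boundary, boundary.break

final case class TreeNode(
    value: Int = 0,
    left: TreeNode | Null = null,
    right: TreeNode | Null = null
)

object Approach3Boundary:
  def isBalanced(root: TreeNode): Boolean =
    type Depth = Int
    type Visited = Boolean
    val stack = Stack[(TreeNode, Visited)]((root, false))
    val depth = Map[TreeNode | Null, Depth]()

    boundary:
      while stack.nonEmpty do
        val (node, visited) = stack.pop()
        if visited then
          // A missing (null) child is never in the map, so it counts as -1.
          val left = depth.getOrElse(node.left, -1)
          val right = depth.getOrElse(node.right, -1)
          // Unbalanced: leave the whole `boundary` block with `false` right away.
          if math.abs(left - right) > 1 then break(false)
          depth += node -> (math.max(left, right) + 1)
        else
          stack.push((node, true))
          if node.left != null then stack.push((node.left.nn, false))
          if node.right != null then stack.push((node.right.nn, false))
      // The stack ran out without finding an imbalance.
      true
```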

It’s probably web-scale already, but I think we can improve upon it.

Approach #4

To be honest, the previous approach wasn’t that bad, but I’m not the biggest fan of vars when they can be avoided.

object Approach4:
  import scala.collection.mutable.{Stack, Map}

  def isBalanced(root: TreeNode): Boolean =
    type Depth = Int
    type Visited = Boolean
    val stack = Stack[(TreeNode, Visited)]((root, false))
    val depth = Map.empty[TreeNode | Null, Depth]

    @scala.annotation.tailrec
    def go(): Boolean =
      if stack.isEmpty then true
      else
        val (node, visited) = stack.pop()

        if visited then
          val left = depth.getOrElse(node.left, -1)
          val right = depth.getOrElse(node.right, -1)

          if math.abs(left - right) > 1 then false
          else
            depth += node -> (math.max(left, right) + 1)
            go()
        else
          stack.push((node, true))
          if node.left != null then stack.push((node.left.nn, false))
          if node.right != null then stack.push((node.right.nn, false))
          go()

    go()

This approach builds on top of the previous one. We just replaced the loop with a tail-recursive call and eliminated all the vars. The mutable state in the form of mutable.Map and mutable.Stack is still there, but it’s local so I don’t mind it and I’d just ship it as is.

Approach #5

We have already moved from an imperative loop to tail recursion, so why not go one step further and operate on an immutable Stack and Map?

object Approach5:
  def isBalanced(root: TreeNode): Boolean =
    type Depth = Int

    @scala.annotation.tailrec
    def go(stack: List[(TreeNode, Boolean)], depth: Map[TreeNode | Null, Depth]): Boolean =
      stack match
        case Nil => true
        case (node, visited) :: tail =>
          if visited then
            val left = depth.getOrElse(node.left, -1)
            val right = depth.getOrElse(node.right, -1)

            if math.abs(left - right) > 1 then false
            else go(tail, depth.updated(node, math.max(left, right) + 1))
          else
            go(
              Option(node.right).toList.map(_.nn -> false) :::
                Option(node.left).toList.map(_.nn -> false) :::
                (node, true) ::
                tail,
              depth
            )

    go(
      stack = (root, false) :: Nil,
      depth = Map.empty[TreeNode | Null, Depth]
    )

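To keep it tail-recursive, we had to move the stack and the depth map into go’s parameters. The surprising thing might be the introduction of List instead of Stack. It turns out that the immutable stack is actually a List: scala.collection.immutable.Stack was deprecated and then removed in Scala 2.13, with List recommended in its place.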
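In practice the mapping is direct. Pushing is prepending with `::` and popping is a pattern match on the head. Both are O(1), and both return a new list while the old one stays valid. The tiny `ListAsStack` helper below is hypothetical and only spells this out. Approach5 does the same thing inline.

```scala
object ListAsStack:
  // Push: prepend with `::`. O(1), and `stack` itself is left unchanged.
  def push[A](stack: List[A], a: A): List[A] = a :: stack

  // Pop: split off the head with a pattern match. Also O(1).
  // Returns None for an empty stack instead of throwing.
  def pop[A](stack: List[A]): Option[(A, List[A])] = stack match
    case head :: tail => Some((head, tail))
    case Nil          => None
```

This mirrors Approach5. Matching `(node, visited) :: tail` is the pop, and building `... ::: (node, true) :: tail` is a batch of pushes.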

Now it’s fully immutable and tail-recursive.

Summary

Which approach is the best?

I don’t have a strong opinion on this. It’s all a matter of trade-offs, where you have to weigh ease and speed of development, maintainability, and performance. My suggestion would be to use an immutable, functional style by default, but go imperative when you have to.

Don’t be afraid of loops! 🔄