Checking your tree balance in Scala
Scala is known for having multiple ways of solving the same problem. I would argue that this feature is not really Scala-specific, but let’s entertain this thought for a moment and deepen the stereotype even further.
Is your tree even balanced? 🌳
To explore this, we will use the famous binary tree balancing problem, also known as LeetCode 110:
Given a binary tree, determine whether it is height balanced, i.e. the depth of the two subtrees of each node never differs by more than one.
First, let’s define the input:
final case class TreeNode(
  value: Int = 0,
  left: TreeNode | Null = null,
  right: TreeNode | Null = null
)
Yes, I used nulls in Scala. This is fine, because here I assume the use of Scala 3’s -Yexplicit-nulls compiler option. Without it, it’s better to stick to the good old Option.
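For comparison, here is a rough sketch of what the Option-based variant could look like. The names OptTreeNode and height are made up for this illustration and are not used anywhere else in this post; the height helper is plain recursion, so it is not stack-safe.

```scala
// Hypothetical Option-based node: an absent child is None instead of null.
final case class OptTreeNode(
  value: Int = 0,
  left: Option[OptTreeNode] = None,
  right: Option[OptTreeNode] = None
)

// Illustrative helper: an empty subtree has height 0, a single node has height 1.
def height(node: Option[OptTreeNode]): Int =
  node match
    case None    => 0
    case Some(n) => math.max(height(n.left), height(n.right)) + 1
```

The trade-off is an extra Some wrapper per child, in exchange for null safety without any compiler flags.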
Now, we would like something that implements the following signature:
def isBalanced(root: TreeNode): Boolean = ???
Approach #1
When I hear both Tree and Height, the first thing that comes to my mind is Depth-First-Search and some recursion. Let’s have a go at it:
object Approach1:
  def isBalanced(root: TreeNode): Boolean =
    def go(node: TreeNode | Null): (Boolean, Int) =
      if node == null then (true, 0)
      else
        val (leftBalanced, leftHeight) = go(node.left)
        val (rightBalanced, rightHeight) = go(node.right)
        val balanced = leftBalanced && rightBalanced && math.abs(leftHeight - rightHeight) <= 1
        val height = math.max(leftHeight, rightHeight) + 1
        (balanced, height)
    go(root)._1
I think that’s a pretty neat solution - we only traverse the tree once and collect heights at the same time. The time complexity should be O(n) and the space complexity O(h).
There’s just one problem with this solution - it’s not stack-safe!
With some larger inputs, you might get a stack overflow error.
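To make this concrete, here is a small sanity check. It repeats the recursive check under the made-up name isBalancedRec so that the snippet is self-contained, and adds a hypothetical helper, leftChain, that builds a degenerate (linked-list-like) tree iteratively. Treat it as a sketch rather than a benchmark.

```scala
final case class TreeNode(
  value: Int = 0,
  left: TreeNode | Null = null,
  right: TreeNode | Null = null
)

// Same recursive idea as Approach #1, repeated here for self-containment.
def isBalancedRec(root: TreeNode): Boolean =
  def go(node: TreeNode | Null): (Boolean, Int) =
    if node == null then (true, 0)
    else
      val (leftBalanced, leftHeight) = go(node.left)
      val (rightBalanced, rightHeight) = go(node.right)
      val balanced = leftBalanced && rightBalanced && math.abs(leftHeight - rightHeight) <= 1
      (balanced, math.max(leftHeight, rightHeight) + 1)
  go(root)._1

// Hypothetical helper: a chain of `depth` nodes (depth >= 1), each hanging off
// the left of its parent. Built with a fold, so construction uses no recursion.
def leftChain(depth: Int): TreeNode =
  (1 until depth).foldLeft(TreeNode())((child, i) => TreeNode(i, left = child))

val small = TreeNode(1, TreeNode(2), TreeNode(3)) // balanced
val chain = leftChain(3)                          // 3 nodes in a row: not balanced
```

Small inputs behave as expected. Something like isBalancedRec(leftChain(1_000_000)) recurses once per level, so with typical default JVM stack sizes I would expect it to end in a StackOverflowError; the exact threshold depends on the JVM and its settings.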
Approach #2
Ok, let’s try something different. If we want to avoid overflowing the stack, we can just trampoline the whole thing.
This technique essentially moves the problem from the stack to the heap. Instead of doing the computation in place, we suspend it in a continuation and execute it later. In Scala we have a ready-made helper just for that in the form of scala.util.control.TailCalls. Let’s use it:
object Approach2:
  import scala.util.control.TailCalls.*
  def isBalanced(root: TreeNode): Boolean =
    def go(node: TreeNode | Null): TailRec[Option[Int]] =
      if node == null then done(Some(0))
      else if node.left == null && node.right == null then done(Some(1))
      else
        tailcall(go(node.left)).flatMap:
          case None => done(None)
          case Some(leftHeight) =>
            tailcall(go(node.right)).map:
              case Some(rightHeight) if math.abs(leftHeight - rightHeight) <= 1 =>
                Some(math.max(leftHeight, rightHeight) + 1)
              case _ => None
    go(root).result.isDefined
I like this approach too. We added a few flatMaps, but the general flow is the same.
As you might guess, this adds some runtime overhead - we are instantiating new objects left and right.
In most situations this is probably fine, but if you want more performance then you need to look elsewhere.
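If the trampolining mechanics feel a bit opaque, it may help to see them in isolation. Below is a minimal sketch using the classic mutually recursive isEven / isOdd pair (these two functions are illustrative and unrelated to the tree problem). Each call returns a suspended step via tailcall instead of growing the call stack, and result runs those steps in a loop.

```scala
import scala.util.control.TailCalls.*

// Each step is described as a TailRec value rather than executed immediately.
def isEven(n: Int): TailRec[Boolean] =
  if n == 0 then done(true) else tailcall(isOdd(n - 1))

def isOdd(n: Int): TailRec[Boolean] =
  if n == 0 then done(false) else tailcall(isEven(n - 1))

// `result` drives the trampoline; a million hops do not deepen the call stack.
val millionIsEven: Boolean = isEven(1_000_000).result
```

The price is visible here too: every hop allocates a small object on the heap, which is the overhead mentioned above.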
Approach #3
Ok, back to basics. Let’s do this in the crudest way possible by introducing some mutable state and imperative loops 🫣.
object Approach3:
  import scala.collection.mutable.{Stack, Map}
  def isBalanced(root: TreeNode): Boolean =
    type Depth = Int
    type Visited = Boolean
    val stack = Stack[(TreeNode, Visited)]((root, false))
    val depth = Map[TreeNode | Null, Depth]()
    var result = true
    var isFinished = false
    while !isFinished do
      if stack.isEmpty then isFinished = true
      else
        val (node, visited) = stack.pop()
        if visited then
          val left = depth.getOrElse(node.left, -1)
          val right = depth.getOrElse(node.right, -1)
          if math.abs(left - right) > 1 then
            result = false
            isFinished = true
          else depth += node -> (math.max(left, right) + 1)
        else
          stack.push((node, true))
          if node.left != null then stack.push((node.left.nn, false))
          if node.right != null then stack.push((node.right.nn, false))
    result
Here, instead of relying on an implicit function call stack, we make use of an explicit (mutable) Stack. In this particular implementation, we visit each node twice: the first time to add all of its children to the Stack, the second time to actually inspect the heights cached in the depth Map. This could also be done in two separate (not nested) loops. This gives us O(2n) complexity, so… still O(n).
To break out of the loop we use the isFinished var. There are returns in Scala, but they are frowned upon, and non-local returns are deprecated in Scala 3. Scala 3 brings a new boundary / break mechanism that we could have used here instead, but I’ll leave it as an exercise for the reader.
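For the impatient, here is one possible take on that exercise. It is only a sketch: it assumes a Scala 3 version that ships scala.util.boundary (3.3 onward, as far as I know), the object name Approach3WithBoundary is made up, and TreeNode is repeated to keep the snippet self-contained.

```scala
import scala.collection.mutable
import scala.util.boundary
import scala.util.boundary.break

final case class TreeNode(
  value: Int = 0,
  left: TreeNode | Null = null,
  right: TreeNode | Null = null
)

object Approach3WithBoundary:
  def isBalanced(root: TreeNode): Boolean =
    val stack = mutable.Stack[(TreeNode, Boolean)]((root, false))
    val depth = mutable.Map.empty[TreeNode | Null, Int]
    boundary:
      while stack.nonEmpty do
        val (node, visited) = stack.pop()
        if visited then
          val left = depth.getOrElse(node.left, -1)
          val right = depth.getOrElse(node.right, -1)
          // Leave the whole boundary block early with the final answer.
          if math.abs(left - right) > 1 then break(false)
          depth += node -> (math.max(left, right) + 1)
        else
          stack.push((node, true))
          if node.left != null then stack.push((node.left.nn, false))
          if node.right != null then stack.push((node.right.nn, false))
      // The loop ran to completion, so no unbalanced node was found.
      true
```

Both vars are gone: break(false) exits the enclosing boundary with false, and falling out of the loop yields true.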
Though it’s probably web-scale already, I think we can improve upon it.
Approach #4
To be honest, the previous approach wasn’t that bad, but I’m not the biggest fan of vars if they can be avoided.
object Approach4:
  import scala.collection.mutable.{Stack, Map}
  def isBalanced(root: TreeNode): Boolean =
    type Depth = Int
    type Visited = Boolean
    val stack = Stack[(TreeNode, Visited)]((root, false))
    val depth = Map.empty[TreeNode | Null, Depth]
    @scala.annotation.tailrec
    def go(): Boolean =
      if stack.isEmpty then true
      else
        val (node, visited) = stack.pop()
        if visited then
          val left = depth.getOrElse(node.left, -1)
          val right = depth.getOrElse(node.right, -1)
          if math.abs(left - right) > 1 then false
          else
            depth += node -> (math.max(left, right) + 1)
            go()
        else
          stack.push((node, true))
          if node.left != null then stack.push((node.left.nn, false))
          if node.right != null then stack.push((node.right.nn, false))
          go()
    go()
This approach builds on top of the previous one. We just replaced the loop with a tail-recursive call and eliminated all the vars. The mutable state in the form of mutable.Map and mutable.Stack is still there, but it’s local, so I don’t mind it and I’d just ship it as is.
Approach #5
We have already moved from an imperative loop to tail recursion, so why not go one step further and operate on an immutable Stack and Map?
object Approach5:
  def isBalanced(root: TreeNode): Boolean =
    type Depth = Int
    @scala.annotation.tailrec
    def go(stack: List[(TreeNode, Boolean)], depth: Map[TreeNode | Null, Depth]): Boolean =
      stack match
        case Nil => true
        case (node, visited) :: tail =>
          if visited then
            val left = depth.getOrElse(node.left, -1)
            val right = depth.getOrElse(node.right, -1)
            if math.abs(left - right) > 1 then false
            else go(tail, depth.updated(node, math.max(left, right) + 1))
          else
            go(
              Option(node.right).toList.map(_.nn -> false) :::
                Option(node.left).toList.map(_.nn -> false) :::
                (node, true) ::
                tail,
              depth
            )
    go(
      stack = (root, false) :: Nil,
      depth = Map.empty[TreeNode | Null, Depth]
    )
To keep it tail-recursive, we had to move the stack and the depth map into go’s parameters. The surprising thing might be the introduction of List instead of Stack. It turns out that the immutable stack is actually a List.
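If List-as-a-stack is new to you, here is a tiny illustration. Prepending with :: plays the role of push, and pattern matching on head :: tail plays the role of pop. The pop helper is a made-up name for this sketch, not a standard library method.

```scala
val stack0 = List.empty[Int]
val stack1 = 1 :: stack0 // push 1
val stack2 = 2 :: stack1 // push 2, which is now on top

// Hypothetical helper: returns the top element and the remaining stack, if any.
def pop[A](stack: List[A]): Option[(A, List[A])] =
  stack match
    case head :: tail => Some((head, tail))
    case Nil          => None
```

Both operations are O(1), and older versions such as stack1 stay valid after a push, which is exactly what lets go pass the stack along as a plain parameter.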
Now it’s fully immutable and tail-recursive.
Summary
Which approach is the best?
I don’t have a strong opinion on this. It’s all a matter of trade-offs, where you have to consider ease and speed of development, maintainability, and performance. My suggestion would be to use an immutable, functional style by default, but go imperative when you have to.
Don’t be afraid of loops! 🔄