Silence of the λs

Algorithmic Tradeoffs In Scala

25 Aug 2013

As I'm prone to doing, several weeks ago I hijacked a discussion on HackerNews covering Go to steer the discussion towards Scala. The context was around Go's switch statement, which is somewhat more powerful than that found in C et al. I couldn't resist the opportunity to point out how awesome pattern matching in Scala is, noting that it is like "switch on steroids" according to scala-lang. The toy example I came up with centered around the problem of getting all leaf node values from a binary tree.

First take

sealed abstract class Node[A]
case class Fork[A](value: A, left: Node[A], right: Node[A]) extends Node[A]
case class Leaf[A](value: A) extends Node[A]

object Src {

  def getLeafNodeValues[A](node: Node[A]): List[A] = node match {
    case Fork(_, left, right) => getLeafNodeValues(left) ++ getLeafNodeValues(right)
    case Leaf(value) => List(value)
  }

}

If you're at all familiar with pattern matching, this is a pretty basic example. A node in a tree must be either a fork or a leaf, so getting leaf nodes is a simple task once we've determined what type of node we are dealing with. Pattern matching provides syntactic sugar around determining the type of a node and extracting the values from it into local variables. This is actually quite similar to an example in Twitter's Scala School, where they note that the implementation is "obviously correct".
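
To make this concrete, here's a quick run on a small tree (the tree shape is my own example, not from Scala School):

```scala
sealed abstract class Node[A]
case class Fork[A](value: A, left: Node[A], right: Node[A]) extends Node[A]
case class Leaf[A](value: A) extends Node[A]

def getLeafNodeValues[A](node: Node[A]): List[A] = node match {
  case Fork(_, left, right) => getLeafNodeValues(left) ++ getLeafNodeValues(right)
  case Leaf(value) => List(value)
}

// A small tree:      1
//                   / \
//                  2   5
//                 / \
//                3   4
val tree = Fork(1, Fork(2, Leaf(3), Leaf(4)), Leaf(5))
println(getLeafNodeValues(tree)) // List(3, 4, 5)
```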

Ok, neat; this is idiomatic functional code, but it's also exactly the sort of algorithm used to demonstrate the dangers of recursion. On a sufficiently deep tree this will throw a StackOverflowError, since the recursion nests as deeply as the tree itself.
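
To see the overflow in practice, here's a sketch that builds a degenerate, left-leaning tree iteratively (so construction itself can't overflow) and then runs the naive function on it; a depth of 1,000,000 is far beyond a default JVM stack:

```scala
sealed abstract class Node[A]
case class Fork[A](value: A, left: Node[A], right: Node[A]) extends Node[A]
case class Leaf[A](value: A) extends Node[A]

def getLeafNodeValues[A](node: Node[A]): List[A] = node match {
  case Fork(_, left, right) => getLeafNodeValues(left) ++ getLeafNodeValues(right)
  case Leaf(value) => List(value)
}

// foldLeft builds the tree in a loop, so only the traversal recurses
val deep = (1 to 1000000).foldLeft[Node[Int]](Leaf(0)) { (acc, i) =>
  Fork(i, acc, Leaf(i))
}

try {
  getLeafNodeValues(deep)
} catch {
  case _: StackOverflowError => println("blew the stack")
}
```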

The most interesting man in the world proving that nobody is above stack overflow

Tail call optimization

As is true for many other recursive data structures, recursion is a natural way to process a binary tree. It's certainly possible to do so iteratively, but some algorithms may be awkward or unidiomatic. Luckily, there's a way to recurse arbitrarily deep without blowing the stack: tail call optimization. In plain English, a call is a tail call when it is the last thing a function does: the function's return value is either a plain value or the result of that final call, with no work left pending. If a recursive call does not require preserving any state from the current invocation, then in principle the stack frame allocated to that invocation can be reused. A simple example of this can be seen in the factorial function:

// Not tail-recursive: the multiplication still happens
// after the recursive call returns
def factorial(n: Int): Int = n match {
  case n if n <= 0 => 1
  case _ => n * factorial(n - 1)
}

// Tail-recursive: the recursive call is the last thing the function does
// (requires import scala.annotation.tailrec)
@tailrec
def betterFactorial(n: Int, acc: Int = 1): Int = n match {
  case n if n <= 0 => acc
  case _ => betterFactorial(n - 1, acc * n)
}

Writing a tail-recursive function essentially just requires passing the state you would otherwise have kept in local variables into the function as parameters (here, the accumulator acc). The JVM doesn't implement true tail call optimization, but the Scala compiler will optimize self-recursive tail calls into efficient loops. By annotating the function with @tailrec, we are asking the compiler to fail compilation if it is unable to do so.
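
To see the loop transformation pay off, here's a variant (my own, using BigInt to avoid numeric overflow) that handles inputs far deeper than a naive recursion could comfortably go:

```scala
import scala.annotation.tailrec

// Tail-recursive factorial with a BigInt accumulator; the compiler turns
// the recursion into a loop, so stack depth stays constant.
@tailrec
def factorialBig(n: Int, acc: BigInt = 1): BigInt = n match {
  case n if n <= 0 => acc
  case _ => factorialBig(n - 1, acc * n)
}

println(factorialBig(5)) // 120
// 20,000 recursive calls, but constant stack space thanks to the loop
println(factorialBig(20000).toString.length)
```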

Taking another swing at our leaf value collection function, we could write it like so:

import scala.annotation.tailrec

sealed abstract class Node[A]
case class Fork[A](value: A, left: Node[A], right: Node[A]) extends Node[A]
case class Leaf[A](value: A) extends Node[A]

object Src {

  def getLeafNodeValues[A](node: Node[A]): List[A] = getValuesFromList(List.empty[A], List(node))

  @tailrec
  def getValuesFromList[A](accValues: List[A], nodeList: List[Node[A]]): List[A] = nodeList match {
    case head :: tail => head match {
      case Fork(_, left, right) => getValuesFromList(accValues, left :: right :: tail)
      case Leaf(value) => getValuesFromList(value :: accValues, tail)
    }
    case _ => accValues
  }

}

While this code will not blow the stack, it's less obvious what is going on. It also isn't free: the explicit worklist holds the pending right siblings along the current path, so our additional memory usage is roughly linear in the depth of the tree (on top of the accumulated results).
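
To convince ourselves it behaves like the first version, here's a quick run on a small tree (my own example; note that prepending to the accumulator reverses the order of the leaves):

```scala
import scala.annotation.tailrec

sealed abstract class Node[A]
case class Fork[A](value: A, left: Node[A], right: Node[A]) extends Node[A]
case class Leaf[A](value: A) extends Node[A]

@tailrec
def getValuesFromList[A](accValues: List[A], nodeList: List[Node[A]]): List[A] =
  nodeList match {
    case head :: tail => head match {
      // A fork contributes its children to the worklist
      case Fork(_, left, right) => getValuesFromList(accValues, left :: right :: tail)
      // A leaf contributes its value to the accumulator
      case Leaf(value) => getValuesFromList(value :: accValues, tail)
    }
    case _ => accValues
  }

def getLeafNodeValues[A](node: Node[A]): List[A] =
  getValuesFromList(List.empty[A], List(node))

val tree = Fork(1, Fork(2, Leaf(3), Leaf(4)), Leaf(5))
println(getLeafNodeValues(tree)) // List(5, 4, 3)
```

Prepending keeps each step O(1); a trailing reverse would restore the original order.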

Here's where things get fuzzy

I began wondering if it was possible to get the best of both worlds: is there a way to collect all of the leaf nodes from a binary tree (as we've defined it here) in a tail call optimized manner while only using O(1) additional storage? Note: there is the unavoidable memory cost of gathering the leaf node values themselves, which is O(2^d) in a perfectly balanced tree of depth d, since such a tree has 2^d leaves.

The first thing that leapt to mind was an iterator. By creating a data structure that can traverse the tree on demand, we could gather up the leaf nodes. Great. The problem is: how? Here was my line of reasoning:

  1. To implement an O(1)-space iterator, there must be a way to determine, from a single node alone, which node to traverse next. We need a getNext function.
  2. To implement this getNext function, we need some way of determining, from a node, which paths have already been traversed.
  3. One way of doing that is to preserve the list of traversed nodes, but that conflicts with the goal of using O(1) space.
  4. Another option is to use mutable state (somehow mark the node as traversed), but I'd prefer not to, since that leaves only two yucky choices:
    1. Make the leaf gathering function non-referentially transparent, i.e. gathering the leaves permanently modifies the state of the tree.
    2. Walk the tree a second time to "reset" it.

Alec Baldwin scoffing at the idea of a lack of referential transparency

The root of the problem is that we need a way to get from a child back to its parent, and the only way to do that (with the current implementation) is to preserve that state explicitly.

Hard questions, hard answers

I took to StackOverflow to see if the wise folk there had some insights. We essentially ended up at a "short answer: 'Yes' with an 'If,' long answer: 'No' -- with a 'But.'".

Short answer: "yes, you can implement this best-of-both-worlds approach IF you have a link to the parent or modify the tree".

Long answer: "No, you cannot implement the algorithm as it stands, BUT you can get relatively efficient memory usage".

I'll admit to not understanding that last one...as is evident from my previous post, I am not exactly a Haskell wizard.

Taking the easy way out

If we modify the implementation such that every node has an optional link back to its parent (optional since the root won't have one, and null isn't idiomatic), we can mix Traversable into Node, with the implementation of getNextAndApply provided separately by each sub-type. (Constructing such a tree takes a little care, since parents and children now refer to each other.) The implementation could look like this:

EDIT: Updated below code, discovered a bug where the function passed to foreach was being applied multiple times to Forks.

import scala.annotation.tailrec

sealed abstract class Node[A] extends Traversable[Node[A]] {
  def parent: Option[Node[A]]
  def value: A

  def getNextAndApply[B](prev: Node[A], f: Node[A] => B): Option[Node[A]]

  def foreach[B](f: Node[A] => B): Unit = applyForeach(Some(this), this, f)

  @tailrec
  final def applyForeach[B](current: Option[Node[A]], previous: Node[A], f: Node[A] => B): Unit = current match {
    // Dispatch getNextAndApply on the node being visited, not on the
    // original receiver, or the walk never leaves the root
    case Some(n) => applyForeach(n.getNextAndApply(previous, f), n, f)
    case None => ()
  }
}

case class Fork[A](value: A, left: Node[A], right: Node[A], parent: Option[Node[A]]) extends Node[A] {

  def getNextAndApply[B](prev: Node[A], f: Node[A] => B): Option[Node[A]] = {
    // eq (reference equality) rather than ==: recursive case class
    // equality could loop through the cyclic parent/child links
    if (prev eq left) {
      // f(this) here would result in in-order traversal
      Some(right)
    } else if (prev eq right) {
      f(this) // post-order traversal
      parent
    } else {
      // f(this) here would result in pre-order traversal
      Some(left)
    }
  }

}

case class Leaf[A](value: A, parent: Option[Node[A]]) extends Node[A] {

  def getNextAndApply[B](prev: Node[A], f: Node[A] => B): Option[Node[A]] = {
    f(this) // This looks funny.
    parent
  }

}

object Src {

  def getLeafNodeValues[A](node: Node[A]): List[A] = node.filter {
    case _: Leaf[_] => true
    case _ => false
  }.map { _.value }.toList

}

One cool consequence of implementing Traversable is that we can now build the leaf-node-value-gathering function from combinators rather than explicit recursion. The actual grunt work of efficiently traversing the tree is always going to be the same, so this separates traversal from whatever else we want to do with the tree. Neat. =)

In addition, although this isn't implemented in this API, it would be trivial to pick between pre-order, in-order, and post-order traversal based on caller needs.
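
For completeness, here's a self-contained sketch of the parent-pointer walk as a plain loop. The names (MNode, MFork, MLeaf, leafValues) are mine, and I cheat with a mutable parent field (unlike the case classes above) so the cyclic parent/child links are easy to construct:

```scala
import scala.collection.mutable.ListBuffer

// A simplified, mutable-parent variant of the tree, for demonstration only
sealed trait MNode[A] {
  var parent: Option[MNode[A]] = None
  def value: A
}
final class MFork[A](val value: A, val left: MNode[A], val right: MNode[A]) extends MNode[A] {
  // Wire up the back-links at construction time
  left.parent = Some(this)
  right.parent = Some(this)
}
final class MLeaf[A](val value: A) extends MNode[A]

// Walks the tree using only parent links: O(1) space beyond the result
def leafValues[A](root: MNode[A]): List[A] = {
  val buf = ListBuffer.empty[A]
  var prev: MNode[A] = root
  var current: Option[MNode[A]] = Some(root)
  while (current.isDefined) {
    val n = current.get
    current = n match {
      case f: MFork[A] =>
        if (prev eq f.left) Some(f.right)  // finished the left subtree, go right
        else if (prev eq f.right) f.parent // finished both subtrees, go up
        else Some(f.left)                  // arrived from above, go left
      case l: MLeaf[A] =>
        buf += l.value
        l.parent
    }
    prev = n
  }
  buf.toList
}

val tree = new MFork(1, new MFork(2, new MLeaf(3), new MLeaf(4)), new MLeaf(5))
println(leafValues(tree)) // List(3, 4, 5)
```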

Conclusion

By modifying our implementation of binary trees, we were able to achieve the best case as far as traversal efficiency is concerned. We also confined the harder-to-understand code to a single, reusable unit. Without the parent links, some inefficiency was always going to be present, due to the nature of this problem.

Craig from South Park indicating how happy he would be writing Scala all day.

...and Scala is awesome.