# Trampolining and stack safety in Scala

## Making every call a self-recursive tail call

## Problem

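The Scala compiler isn’t great at tail call elimination in general. The JVM has no native support for it, so calls between different functions are never eliminated. This makes functions composed of many smaller, mutually recursive functions prone to stack overflows.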

```
def even[A](lst: List[A]): Boolean = {
  lst match {
    case Nil => true
    case _ :: xs => odd(xs)
  }
}
def odd[A](lst: List[A]): Boolean = {
  lst match {
    case Nil => false
    case _ :: xs => even(xs)
  }
}
even((0 to 1000000).toList) // blows the stack
```

## Solution

The Scala compiler is able to optimize a specific kind of tail call known as a self-recursive call:

```
@scala.annotation.tailrec // compile error if this call could not be optimized
def gcd(a: Int, b: Int): Int =
  if (b == 0) a else gcd(b, a % b)
```
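Conceptually, the compiler turns such a call into a jump back to the start of the method, reusing the current stack frame. The hand-written `gcdLoop` below (a hypothetical name, not what scalac literally emits) shows roughly what the optimized `gcd` amounts to:

```scala
object GcdLoop {
  // Hand-written loop equivalent to the tail-recursive gcd:
  // each "recursive call" just rebinds the parameters and jumps back.
  def gcdLoop(a0: Int, b0: Int): Int = {
    var a = a0
    var b = b0
    while (b != 0) {
      val r = a % b // same arguments as the call gcd(b, a % b)
      a = b
      b = r
    }
    a
  }
}
```

No stack frames are added per iteration, which is why self-recursive tail calls are stack safe.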

We will try to find a way to transform every call into a self-recursive call that the compiler can optimize.

## Thoughts

As the example above shows, a recursive process can involve several functions, but the compiler can only make a single self-recursive function stack safe. So instead of making recursive calls directly, our functions should build a description of the program. A single generic, self-recursive function then walks through that description and evaluates one recursion step at a time.

## Approach #1

Let’s start with an algebraic data type that represents our recursive program:

```
sealed trait Trampoline[A]
case class Done[A](value: A) extends Trampoline[A]
case class More[A](call: () => Trampoline[A]) extends Trampoline[A]
```

**Trampoline** is the term for this technique; think of it as a *program description*.
The program description can be either:

- `Done`, if there are no computations left and we can yield a value, or
- `More`, if there is a recursive function call still to be made.

Here’s an example of a simple program that has two suspended computations and yields the value 42:

```
More(() => More(() => Done(42)))
```

Of course we won’t create it manually; our recursive functions will:

```
// List, not Seq: the `::` extractor only matches lists
def even[A](lst: List[A]): Trampoline[Boolean] = {
  lst match {
    case Nil => Done(true)
    case _ :: xs => More(() => odd(xs))
  }
}
def odd[A](lst: List[A]): Trampoline[Boolean] = {
  lst match {
    case Nil => Done(false)
    case _ :: xs => More(() => even(xs))
  }
}
```

So now when we call `even((0 to 1000000).toList)`, we get back a `More[Boolean]` containing a function that computes `odd` for the remaining `1000000` elements. To get the actual boolean value, we need to traverse all the nested `More`s until we finally reach `Done`. Let’s write a function for that:

```
@annotation.tailrec
def run[A](trampoline: Trampoline[A]): A = {
  trampoline match {
    case Done(v) => v
    case More(t) => run(t()) // <- tail recursive, yay
  }
}
```

The `Done` case is straightforward. The `More` case first executes our suspended function to get either a `More[A]` or a `Done[A]`, and then recursively calls `run` on that. This kind of recursion is a self-recursive tail call, so it is safe: the compiler turns it into a loop.
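Here is the approach so far as a self-contained sketch. The wrapping `EvenOdd` object is only there so the snippet compiles on its own, and it uses `List` patterns so that the `::` extractor matches. Note that `0 to 1000000` has 1,000,001 elements, so `even` evaluates to `false`:

```scala
object EvenOdd {
  sealed trait Trampoline[A]
  final case class Done[A](value: A) extends Trampoline[A]
  final case class More[A](call: () => Trampoline[A]) extends Trampoline[A]

  // Each step returns a description instead of recursing directly.
  def even[A](lst: List[A]): Trampoline[Boolean] = lst match {
    case Nil     => Done(true)
    case _ :: xs => More(() => odd(xs))
  }

  def odd[A](lst: List[A]): Trampoline[Boolean] = lst match {
    case Nil     => Done(false)
    case _ :: xs => More(() => even(xs))
  }

  // Self-recursive interpreter; @tailrec makes the compiler verify the optimization.
  @annotation.tailrec
  def run[A](trampoline: Trampoline[A]): A = trampoline match {
    case Done(v) => v
    case More(t) => run(t())
  }
}

// EvenOdd.run(EvenOdd.even((0 to 1000000).toList)) returns false without overflowing
```

The `@annotation.tailrec` annotation makes the compiler reject `run` if its recursive call could not be turned into a loop.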

One small refactoring looks unnecessary now but will be useful a bit later: separating the execution of a single step from the execution of all steps:

```
def resume[A](t: Trampoline[A]): Either[() => Trampoline[A], A] = t match {
  case Done(v) => Right(v)
  case More(k) => Left(k)
}
@annotation.tailrec
def run[A](t: Trampoline[A]): A = resume(t) match {
  case Right(value) => value
  case Left(more) => run(more())
}
```
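Separating a single step (`resume`) from the loop that repeats it makes it easy to write other interpreters over the same program description. As a small sketch, the hypothetical `countSteps` below runs the two-step program from earlier while counting how many suspended calls it had to force:

```scala
object StepCounter {
  sealed trait Trampoline[A]
  final case class Done[A](value: A) extends Trampoline[A]
  final case class More[A](call: () => Trampoline[A]) extends Trampoline[A]

  def resume[A](t: Trampoline[A]): Either[() => Trampoline[A], A] = t match {
    case Done(v) => Right(v)
    case More(k) => Left(k)
  }

  // Hypothetical alternative interpreter built on resume: returns the
  // result together with the number of suspended steps it had to force.
  @annotation.tailrec
  def countSteps[A](t: Trampoline[A], steps: Int = 0): (A, Int) = resume(t) match {
    case Right(value) => (value, steps)
    case Left(more)   => countSteps(more(), steps + 1)
  }

  // The two-step program: More(() => More(() => Done(42)))
  val twoSteps: Trampoline[Int] = More(() => More(() => Done(42)))
}

// StepCounter.countSteps(StepCounter.twoSteps) returns (42, 2)
```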

For convenience, we can also embed these functions into the `Trampoline` trait (making it covariant in `A` along the way):

```
sealed trait Trampoline[+A] {
  def resume: Either[() => Trampoline[A], A] = this match {
    case Done(v) => Right(v)
    case More(k) => Left(k)
  }
  final def runT: A = resume match {
    case Right(value) => value
    case Left(more) => more().runT
  }
}
```

There you go: stack-safe computations ftw.

## Still having stack overflows…

Let’s say you have a State monad implemented like this:

```
case class State[S, +A](runS: S => (S, A)) {
  def map[B](f: A => B): State[S, B] =
    State[S, B](s => {
      val (s1, a) = runS(s)
      (s1, f(a))
    })
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B](s => {
      val (s1, a) = runS(s)
      f(a) runS s1
    })
}
```

And a usage example:

```
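// Helpers not shown in the article; minimal assumed definitions
def pureState[S, A](a: A): State[S, A] = State[S, A](s => (s, a))
def getState[S]: State[S, S] = State[S, S](s => (s, s))
def setState[S](s: S): State[S, Unit] = State[S, Unit](_ => (s, ()))

def zipIndex[A](as: Seq[A]): State[Int, List[(Int, A)]] =
  as.foldLeft(pureState[Int, List[(Int, A)]](List()))((acc, a) => for {
    xs <- acc
    n <- getState[Int] // explicit type argument, otherwise S is inferred as Nothing
    _ <- setState(n + 1)
  } yield (n, a) :: xs)
zipIndex(0 to 1000000).runS(0) // blows the stack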
```

It looks innocent, but what we don’t see is the many `flatMap` calls chained together, consuming all the stack memory during execution.

```
flatMap(flatMap(flatMap(...
```

Different use case, same problem: `flatMap` is not a self-recursive tail call. Can we make it one? Let’s take the `flatMap` implementation and apply the same mechanical transformation we used for `even` and `odd`, making `runS` return a `Trampoline` built from `Done` and `More`:

```
case class State[S, +A](runS: S => Trampoline[(S, A)]) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B](s => More(() => {
      val (s1, a) = runS(s).runT // <- not in the tail position
      More(() => f(a) runS s1)
    }))
}
```

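`runS(s).runT` has to run to completion on the state `s` *before* the inner `More` can be returned. It is not in tail position, so every nested `flatMap` still starts its own `runT` loop on top of the stack, and our trampolining trick does not help.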

## Approach #2

Instead of trying to fix `flatMap`’s implementation, we can suspend the whole `flatMap` operation on the `State` and defer the binding until our `runT` and `resume` functions are called.

```
sealed trait Trampoline[+A] // covariant, as required by State[S, +A] below
case class Done[A](value: A) extends Trampoline[A]
case class More[A](call: () => Trampoline[A]) extends Trampoline[A]
case class FlatMap[A, B](
sub: Trampoline[A],
cont: A => Trampoline[B]) extends Trampoline[B]
```

`FlatMap` is basically the signature of `flatMap` stored in a data structure. With it, `State`’s `map` and `flatMap` look like this:

```
case class State[S, +A](runS: S => Trampoline[(S, A)]) {
  def map[B](f: A => B): State[S, B] = State[S, B](
    runS andThen { tramp => {
      val (s, a) = tramp.runT
      Done((s, f(a)))
    }}
  )
  def flatMap[B](f: A => State[S, B]): State[S, B] = State[S, B](
    runS andThen { tramp => {
      FlatMap[(S, A), (S, B)](tramp, { case (s, a) => f(a).runS(s) })
    }}
  )
}
```

This trick turns nested `flatMap(flatMap(flatMap(...` calls into a data structure `FlatMap(FlatMap(FlatMap(...`, trading stack for heap.

Now we should be able to compose our “trampolined” `State`s and safely use `for` comprehensions, as in the `zipIndex` example. The one thing left is to add the new `FlatMap` case to our tail-recursive interpreter function:

```
sealed trait Trampoline[+A] {
  // final: scalac only optimizes self tail calls in final (or private) methods
  final def resume: Either[() => Trampoline[A], A] = this match {
    case Done(v) => Right(v)
    case More(k) => Left(k)
    case FlatMap(sub, cont) => sub match {
      case Done(v) => cont(v).resume
      case More(k) => Left(() => FlatMap(k(), cont))
      case FlatMap(sub2, cont2) =>
        (FlatMap(sub2, (x: Any) => FlatMap(cont2(x), cont)): Trampoline[A]).resume
    }
  }
  final def runT: A = resume match {
    case Right(value) => value
    case Left(more) => more().runT
  }
}
```

So resuming a `FlatMap`, which holds a `sub` (sub-expression) and a `cont` (continuation function), comes down to matching on the `sub`:

- `Done(v)` means `sub` is the last in the chain of `FlatMap`s, and we just have to thread `v` through `cont` and call `resume`.
- `More(k)` (where `k` is a suspended function) means we can make a step by calling `k`, making the resulting trampoline the `sub` expression of a new `FlatMap` with the same continuation.
- On `FlatMap(sub2, cont2)` we create a `FlatMap` with the inner `sub2` expression. Its continuation applies `cont2` and wraps the result in another `FlatMap` with the outer `cont`, so the left-nested binds are reassociated to the right.

Also, we must explicitly ascribe the type `Trampoline[A]` so that the compiler can figure out that this is in fact a tail-recursive self-call: the receiver of the recursive `resume` call must have the same static type as `this`.
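The reassociation in the third case is justified by the monad associativity law. Reading $m$ as `sub2`, $g$ as `cont2` and $f$ as `cont`, binding the left-nested chain gives the same result as the right-nested one:

$$
(m \;\mathrm{flatMap}\; g) \;\mathrm{flatMap}\; f \;=\; m \;\mathrm{flatMap}\; \big(x \mapsto g(x) \;\mathrm{flatMap}\; f\big)
$$

In terms of the data structure, `FlatMap(FlatMap(sub2, cont2), cont)` becomes `FlatMap(sub2, x => FlatMap(cont2(x), cont))`, which has a simpler `sub` to step into.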

## Still having stack overflows…

Even with these improvements, it is possible to blow the stack when `FlatMap` nesting is too deep: evaluating a `sub` to make a step can cause another `sub` to be evaluated, and so on. So let’s not create deeply nested `FlatMap`s directly (as `State`’s `flatMap` does), but use a helper function that does the following:

```
sealed trait Trampoline[+A] {
  /* resume and runT */
  def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = this match {
    // FlatMap must be checked first: a catch-all case before it would make it unreachable
    case FlatMap(sub, k) => FlatMap(sub, (x: Any) => k(x).flatMap(f))
    case x => FlatMap(x, f)
  }
}
```

Yes, this turns out to be `Trampoline`’s `flatMap`. If it’s called on either `Done` or `More`, we just wrap it in a `FlatMap`. If it’s called on a `FlatMap`, we reassociate the bind to the right.
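To see the reassociation at work, here is a small sketch. So that it type-checks the same way on Scala 2 and Scala 3, it assumes a simplified `FlatMap` that erases its intermediate type to `Any` (unlike the article’s `FlatMap[A, B]`). The `eval` helper is a hypothetical, deliberately naive (not stack-safe) evaluator used only to check that the value is unchanged:

```scala
object Reassociation {
  sealed trait Trampoline[+A] {
    // Same logic as the flatMap above: FlatMap case first, then wrap anything else.
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = this match {
      case FlatMap(sub, cont) => FlatMap(sub, (x: Any) => cont(x).flatMap(f))
      case other              => FlatMap(other, (x: Any) => f(x.asInstanceOf[A]))
    }
  }
  final case class Done[+A](value: A) extends Trampoline[A]
  final case class More[+A](call: () => Trampoline[A]) extends Trampoline[A]
  // Assumption of this sketch: the intermediate type is erased to Any.
  final case class FlatMap[+A](sub: Trampoline[Any], cont: Any => Trampoline[A]) extends Trampoline[A]

  // Written left-nested, but flatMap stores it as FlatMap(Done(1), x => ...)
  val leftNested: Trampoline[Int] =
    Done(1).flatMap(x => Done(x + 1)).flatMap(y => Done(y * 10))

  def isRightNested: Boolean = leftNested match {
    case FlatMap(Done(_), _) => true // the outermost sub is no longer a FlatMap
    case _                   => false
  }

  // Naive recursive evaluator, only for checking the result on tiny inputs.
  def eval[A](t: Trampoline[A]): A = t match {
    case Done(v)            => v
    case More(k)            => eval(k())
    case FlatMap(sub, cont) => eval(cont(eval(sub)))
  }
}

// Reassociation.eval(Reassociation.leftNested) returns 20
```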

With this change in place, let’s refactor `State`’s `flatMap` to use `Trampoline`’s `flatMap`:

```
case class State[S, +A](runS: S => Trampoline[(S, A)]) {
  def map[B](f: A => B): State[S, B] = State[S, B](
    runS andThen { tramp => {
      val (s, a) = tramp.runT
      Done((s, f(a)))
    }}
  )
  def flatMap[B](f: A => State[S, B]): State[S, B] = State[S, B](s =>
    // More suspends runS(s): evaluating it eagerly would recurse once per chained flatMap
    More(() => runS(s).flatMap { case (s1, a) => f(a).runS(s1) })
  )
}
```

Our interpreter function can also create deeply nested `FlatMap`s, so it needs to use `Trampoline`’s `flatMap` as well:

```
sealed trait Trampoline[+A] {
  final def resume: Either[() => Trampoline[A], A] = this match {
    case Done(v) => Right(v)
    case More(k) => Left(k)
    case FlatMap(sub, cont) => sub match {
      case Done(v) => cont(v).resume
      case More(k) => Left(() => FlatMap(k(), cont))
      case FlatMap(sub2, cont2) =>
        (sub2.flatMap((x: Any) => cont2(x).flatMap(cont)): Trampoline[A]).resume // <-
    }
  }
  /* runT, flatMap */
}
```
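Putting the pieces together, here is a self-contained sketch that runs `zipIndex` without overflowing the stack. It makes a few assumptions beyond the snippets above: `FlatMap` erases its intermediate type to `Any` (so the pattern matches type-check the same way on Scala 2 and 3), `resume` and `runT` are `final` and annotated with `@tailrec`, `State.map` is defined via `flatMap` instead of calling `runT`, and `pureState`, `getState` and `setState` are minimal assumed helper definitions:

```scala
import scala.annotation.tailrec

object StacklessState {
  sealed trait Trampoline[+A] {
    // Right-reassociating flatMap: the FlatMap case is checked first.
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = this match {
      case FlatMap(sub, cont) => FlatMap(sub, (x: Any) => cont(x).flatMap(f))
      case other              => FlatMap(other, (x: Any) => f(x.asInstanceOf[A]))
    }

    // One interpreter step; every recursive call is a self tail call on a Trampoline[A].
    @tailrec final def resume: Either[() => Trampoline[A], A] = this match {
      case Done(v) => Right(v)
      case More(k) => Left(k)
      case FlatMap(sub, cont) =>
        sub match {
          case Done(v) => (cont(v): Trampoline[A]).resume
          case More(k) => Left(() => FlatMap(k(), cont))
          case FlatMap(sub2, cont2) =>
            (sub2.flatMap(x => cont2(x).flatMap(cont)): Trampoline[A]).resume
        }
    }

    // Keeps stepping until a value is produced.
    @tailrec final def runT: A = resume match {
      case Right(value) => value
      case Left(more)   => (more(): Trampoline[A]).runT
    }
  }
  final case class Done[+A](value: A) extends Trampoline[A]
  final case class More[+A](call: () => Trampoline[A]) extends Trampoline[A]
  // Assumption of this sketch: the intermediate result type is erased to Any.
  final case class FlatMap[+A](sub: Trampoline[Any], cont: Any => Trampoline[A]) extends Trampoline[A]

  final case class State[S, +A](runS: S => Trampoline[(S, A)]) {
    // More(...) suspends runS(s), so building the trampoline is not recursive.
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State[S, B](s => More(() => runS(s).flatMap { case (s1, a) => f(a).runS(s1) }))
    // Defined via flatMap instead of calling runT (a deviation from the article).
    def map[B](f: A => B): State[S, B] =
      flatMap(a => State[S, B](s => Done((s, f(a)))))
  }

  // Minimal helper definitions assumed by zipIndex.
  def pureState[S, A](a: A): State[S, A] = State[S, A](s => Done((s, a)))
  def getState[S]: State[S, S] = State[S, S](s => Done((s, s)))
  def setState[S](s: S): State[S, Unit] = State[S, Unit](_ => Done((s, ())))

  def zipIndex[A](as: Seq[A]): State[Int, List[(Int, A)]] =
    as.foldLeft(pureState[Int, List[(Int, A)]](Nil)) { (acc, a) =>
      for {
        xs <- acc
        n  <- getState[Int]
        _  <- setState(n + 1)
      } yield (n, a) :: xs
    }
}
```

Running `StacklessState.zipIndex(0 until 100000).runS(0).runT` should yield the final state `100000` and a list whose head is `(99999, 99999)`, since each step prepends its pair.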

## Putting it all together

We started with the data structures `Done` and `More` to represent a program description, and used a self-recursive interpreter function to avoid stack overflows. That was enough for simple use cases but insufficient for more complex ones, like nested monadic `flatMap`s. We represented the `flatMap` operation as a `FlatMap` data structure but faced the same problem: evaluating deeply nested `FlatMap`s could still cause stack overflows. We solved that by delegating the creation of `FlatMap`s to `Trampoline`’s `flatMap`, which reassociates binds to the right.

This article is based on Rúnar Bjarnason’s paper “Stackless Scala With Free Monads”, which I definitely recommend for a more detailed explanation and to see how the Free monad is actually a generalization of the Trampoline monad.