# Reverse State Monad in Scala. Is it possible?

Hello all!

In this post we’re going to have some fun with a mind-breaking thing called Reverse State, and explore the limits of laziness in Scala along the way.

When I see something interesting implemented in a “foreign” programming language, I often want to port it to Scala, just out of pure curiosity about how it would look. And sometimes using a familiar language also helps me understand the presented concepts more deeply. Some time ago I did it with a great book called “Neural Networks and Deep Learning”: here are most of the exercises from the book written in Scala.

This time a totally different thing caught my eye. It was a really nice article about the Reverse State monad and its implementation in Haskell. I had never heard of it before, so implementing it in Scala seemed like an exciting exercise.

And it didn’t disappoint, although the outcome was not the one I expected 🙂 So I decided to make it a story instead of just plain code. I’ll use the technique that Eugene Yokota applied in his “Learning Scalaz”: we’ll follow the source (in Eugene’s case it was Learn You A Haskell For Great Good) piece by piece, discussing and writing the code along the way. Let’s go!

## Prerequisites

To really follow along, the reader should be familiar with the `Monoid` and `Traverse` type classes as well as with the `State` monad in Scala. Also, a quick read of the original article won’t hurt.

## Warmup

There’s a big introductory part that’s dedicated to different ways of scanning a data structure. Let’s not skip this block and use it as a warmup.

*Given a list of Ints, can you produce a cumulative sum of those integers? For example, if we had the list [2, 3, 5, 7, 11, 13], we want to have [2, 5, 10, 17, 28, 41].*

*There are actually many different ways to write this function. Depending on your taste on imperative programming, you can choose anywhere between highly imperative ST-based destructive updates and an idiomatic functional style.*

This should be simple: there’s `scanLeft` in the Scala standard library:

```
def cumulative(l: List[Int]): List[Int] =
  l.scanLeft(0)(_ + _).tail

// List(2, 5, 10, 17, 28, 41)
cumulative(List(2, 3, 5, 7, 11, 13))
```

*For example, what if you want to produce a cumulative sum that is accumulated from the right? So if we had the list [2, 3, 5, 7, 11, 13], we want to have [41, 39, 36, 31, 24, 13].*

Pretty much the same thing here:

```
def cumulativeR(l: List[Int]): List[Int] =
  l.scanRight(0)(_ + _).init

// List(41, 39, 36, 31, 24, 13)
cumulativeR(List(2, 3, 5, 7, 11, 13))
```

*As a Haskell programmer, we have the instinct to generalize things.*

Say no more. Scala developers love it too. We’ll use type classes from cats, but scalaz would work the same way:

```
def cumulative[T[_], W](input: T[W])(implicit M: Monoid[W], T: Traverse[T]): T[W]
def cumulativeR[T[_], W](input: T[W])(implicit M: Monoid[W], T: Traverse[T]): T[W]
```

*But how might we implement such a function? Let’s consider cumulative first. What we really need is to keep track of the running sum as we traverse, and then returning the running sum as the new value. The State monad then becomes helpful.*

State is a well-known concept in Scala, so we can easily follow up with our implementation of `cumulative`:

```
import cats.data.State
import cats.syntax.semigroup._
import cats.{Monoid, Traverse}

// just some sugar to compensate the absence of "let-in" construct
def dup[T](x: T): (T, T) = (x, x)

def cumulative[T[_], W](input: T[W])(implicit M: Monoid[W], T: Traverse[T]): T[W] =
  T.traverse(input)(b => State[W, W](s => dup(s |+| b)))
    .runA(M.empty)
    .value
```

Comments are required here, I guess. “Traversing with State” is a powerful technique for going over a data structure while accumulating information along the way. Processing each element lets you modify the accumulated State “effect”. In this case we’re just accumulating the running sum (according to the provided `Monoid`) and using the same sum as the result value of the state calculation.
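For a concrete `List[Int]`, the same state threading can be written out by hand with `foldLeft`. This is only a sketch to show what the traversal does step by step; `StateByHandSketch` and `cumulativeByHand` are hypothetical names, not part of cats.

```scala
object StateByHandSketch {
  // Threads the running sum (the "state") through the list explicitly.
  def cumulativeByHand(xs: List[Int]): List[Int] = {
    val (_, out) = xs.foldLeft((0, Vector.empty[Int])) { case ((s, acc), b) =>
      val next = s + b     // new state after seeing `b`
      (next, acc :+ next)  // the state doubles as the emitted value, like `dup`
    }
    out.toList
  }
}
```

The `State`-based version does the same bookkeeping, but leaves the plumbing to `traverse`.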

`runA` is analogous to `evalState` in Haskell: it evaluates the thing and returns just the result value (ignoring the accumulated state). And since, for stack safety reasons, State calculations in cats are wrapped into `Eval`, we have to execute it to get our `value` out.

Ok, let’s now, with the original article’s help, try to implement `cumulativeR` with Reverse State.

## Enter the reverse state monad

As mind-boggling as it is, let’s try to digest this definition:

*The reverse state monad, on the other hand, has the same API, except that you can set the state, so that the last time you ask for it, you will get back the value you set in the future.*

Oh man… Well, let’s try to at least port the provided implementation to Scala. But I’ll change two things compared to the original:

- I’ll swap the results in the signature of the `runF` function to be consistent with normal `State` in cats, where the state is returned on the left, and the value is on the right.
- To implement `cumulativeR` we only need an `Applicative`, so I won’t try to provide an instance of `Monad` for `ReverseState` at this moment.

```
import cats.{Applicative, Eval}

final class ReverseState[S, A](val runF: Eval[S => Eval[(S, A)]]) {
  def map[B](f: A => B): ReverseState[S, B] = new ReverseState(
    runF.map(run => s => run(s).map { case (next, a) => next -> f(a) })
  )

  def ap[B](ff: ReverseState[S, A => B]): ReverseState[S, B] =
    new ReverseState(Eval.now(s => for {
      // thanks https://github.com/oleg-py/better-monadic-for for these fancy inline destructurings
      (future, x) <- run(s)
      (past, f) <- ff.run(future)
    } yield (past, f(x))
    ))

  def run(s: S): Eval[(S, A)] = runF.flatMap(f => f(s))
  def runA(s: S): Eval[A] = run(s).map(_._2)
}

object ReverseState {
  def apply[S, A](fn: S => (S, A)): ReverseState[S, A] =
    new ReverseState(Eval.now(s => Eval.now(fn(s))))

  implicit def reverseStateApplicative[S]: Applicative[ReverseState[S, ?]] =
    new Applicative[ReverseState[S, ?]] {
      override def map[A, B](fa: ReverseState[S, A])(f: A => B): ReverseState[S, B] = fa.map(f)
      def pure[A](x: A): ReverseState[S, A] = ReverseState(s => s -> x)
      def ap[A, B](ff: ReverseState[S, A => B])(fa: ReverseState[S, A]): ReverseState[S, B] =
        fa.ap(ff)
    }
}
```

So this is our `ReverseState` applicative. We’re drowning in `Eval`s, but that’s the cost of stack safety: everything here is really similar to the original `State` in cats, except for the `ap` function.
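To see the state flow that `ap` sets up without the `Eval` noise, here is a minimal sketch in plain Scala. It is a simplified, strict stand-in, not the `ReverseState` class itself: `ReverseApSketch`, `Step`, `map2Reverse` and `add` are hypothetical names introduced only for illustration.

```scala
object ReverseApSketch {
  // Simplified stand-in for ReverseState: a bare, strict state function.
  type Step[S, A] = S => (S, A)

  // Combines two steps in the same order as `ap`: the step written
  // second receives the incoming state, and the state it produces
  // ("future") is what the step written first gets to see.
  def map2Reverse[S, A, B, C](first: Step[S, A], second: Step[S, B])(f: (A, B) => C): Step[S, C] =
    s => {
      val (future, b) = second(s)
      val (past, a)   = first(future)
      (past, f(a, b))
    }

  // Adds `b` to the state and returns the new sum as the value.
  def add(b: Int): Step[Int, Int] = s => { val n = b + s; (n, n) }

  // Values are listed left to right, yet the state travels right to left.
  val (finalState, sums) = map2Reverse(add(2), add(3))((a, b) => List(a, b))(0)
  // finalState == 5, sums == List(5, 3)
}
```

Nothing here needs laziness: each step only consumes a state that has already been computed, which is why the applicative part ports to a strict language without trouble.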

And, actually, no “clever use of laziness” is happening in `ap`. It seems like it will show up in `flatMap`, but so far we’re fine without it: the `cumulativeR` implementation works already:

```
import cats.data.State
import cats.syntax.semigroup._
import cats.{Monoid, Traverse}

// just some sugar to compensate the absence of "let-in" construct
def dup[T](x: T): (T, T) = (x, x)

def cumulativeR[T[_], W](input: T[W])(implicit M: Monoid[W], T: Traverse[T]): T[W] =
  T.traverse(input)(b => ReverseState[W, W](s => dup(b |+| s)))
    .runA(M.empty)
    .value
```

We can check that its output is equivalent to the `scanRight`-based implementation:

```
val intList = List(2, 3, 5, 7, 11, 13)
// both return List(41, 39, 36, 31, 24, 13)
intList.scanRight(0)(_ + _).init
cumulativeR(intList)
```

Of course, due to laziness, a similar example in Haskell will not calculate anything until we explicitly ask for an element of the list or trigger the evaluation in some other way.

Let’s go ahead and implement the `Scan` generalization as presented in the original.

*Can we do better? Can we generalize this to more kinds of “cumulative” operations? What if, instead of a simple running sum, what if we want a running average? Or a running standard deviation? Or some entirely new thing such as the running maximum multiplied by the minimum? The only difference between all of those tasks is that the specific state transforming function (the function that was passed to ReverseState) is different.*

Since we don’t have proper universal quantification in Scala, I’ll just lift the `x` into the type parameter list and name it `S` (state). I could make it closer to the original using shapeless `poly`s, but that’s not the topic of the post.

```
import cats.Traverse
import cats.data.State
import cats.syntax.traverse._
import scala.language.higherKinds

class Scan[A, B, S](step: A => State[S, B], initial: S) {
  def scan[F[_]: Traverse](fa: F[A]): F[B] =
    fa.traverse(step).runA(initial).value

  def scanRight[F[_]: Traverse](fa: F[A]): F[B] =
    fa.traverse(a => new ReverseState(step(a).runF)).runA(initial).value
}
```

*Here, we simply unwrap a given state monad action, wrap it again in our ReverseState, do the traversal, then unwrap it again.*

I find it beautiful! And it works, although I decided not to present standard deviation and max*min scans here. The former would require a lot of math and the latter needs proper composition abstractions for `Scan`, which fall out of the scope of this post.

```
val intList = List(2, 3, 5, 7, 11, 13)

val sum = new Scan[Int, Int, Int](elem => State(s => dup(s + elem)), 0)

// List(41, 39, 36, 31, 24, 13)
sum.scanRight(intList)

val mean = new Scan[Double, Double, (Double, Int)](
  elem => State { case (sum, count) =>
    ((sum + elem, count + 1), (sum + elem) / (count + 1))
  },
  (0.0, 0)
)

// List(6.833333333333333, 7.8, 9.0, 10.333333333333334, 12.0, 13.0)
mean.scanRight(intList.map(_.toDouble))
```
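As a sanity check on the running mean, the same right-to-left numbers can be reproduced with nothing but the standard library. This is a sketch for comparison only; `MeanCheckSketch` and `runningMeanR` are hypothetical names, and it works for `List` only rather than any `Traverse`.

```scala
object MeanCheckSketch {
  // Right-to-left running mean using only the standard library.
  def runningMeanR(xs: List[Double]): List[Double] =
    xs.scanRight((0.0, 0)) { case (x, (sum, count)) => (sum + x, count + 1) }
      .init // drop the seed (0.0, 0) that scanRight leaves at the end
      .map { case (sum, count) => sum / count }
}
```

The `(sum, count)` pair plays the role of the `S` type parameter of `Scan`, and the final `map` corresponds to the value returned from each `State` step.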

So that’s it! We implemented everything introduced in the original post, and we did it in Scala! Except for…

## FlatMap! Where is my FlatMap ???

The true power of state lies in the ability to sequence stateful computations using `bind` (or `flatMap` as we know it in Scala). But does it work for `ReverseState`?

In Haskell it definitely does. The laziness of the Haskell runtime allows `bind` to be a finite computation. Let’s take a closer look at the definition from the original article:

```
instance Monad (ReverseState s) where
  mx >>= f =
    ReverseState $ \s ->
      let (a, past) = runReverseState mx future
          (b, future) = runReverseState (f a) s
      in (b, past)
```

It’s clear that there’s a circular computational dependency between `future` and `a`: each of them is calculated in terms of the other. But that is fine: as long as we operate on finite data, at some point the next “future” state won’t be needed, and the Haskell runtime will evaluate only as much as required for the result to be produced.

## So what about Scala?

I would be happy to be proven wrong, but after hours of thought and experiments, after trying to wrap pretty much every tiny thing in `Eval`, I came to the conclusion that **there’s no possible way to implement** `flatMap` for `ReverseState` **in Scala**.

Although there’s a way to encode a circular dependency in Scala, there has to be an explicit exit from the “loop”. In other words, the computation of such a circular dependency in Scala will only complete when, under some runtime condition, the dependency is gone. The reason is simple: the JVM runtime is strict, so it can’t suspend computations that are not needed right now.
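As a small illustration of a circular definition that does terminate, here is a sketch using only the standard library. It assumes Scala 2.13 or later for `LazyList`; `CircularSketch` and `fibs` are names chosen for this example. The exit from the loop is the finite prefix we ask for.

```scala
object CircularSketch {
  // `fibs` refers to itself, but every cell is computed only on demand,
  // so requesting a finite prefix gives the "loop" an exit.
  lazy val fibs: LazyList[BigInt] =
    BigInt(0) #:: BigInt(1) #:: fibs.zip(fibs.tail).map { case (a, b) => a + b }

  // Forces just the first eight cells.
  val firstEight: List[BigInt] = fibs.take(8).toList
}
```

The self-reference is harmless here only because the data structure itself delays evaluation of its tail; a plain strict value defined in terms of itself would never finish computing.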

This restriction still allows some pretty interesting laziness tricks, like the loeb function, for example. But let’s take a look at what an implementation of `flatMap` for `ReverseState` might look like in Scala:

```
def flatMap[B](f: A => ReverseState[S, B]): ReverseState[S, B] =
  new ReverseState(Eval.later { s =>
    lazy val result: Eval[(S, B)] = Eval.defer {
      for {
        // notice the unconditional reference to `result` itself
        futureState <- result.map(_._1)
        (pastState, a) <- run(futureState)
        (_, b) <- f(a).run(s)
      } yield (pastState, b)
    }
    result
  })
```

The circular dependency in the `result` is unconditional: the next leg of the calculation is created regardless of any previous results.

`Eval` won’t help here, because to work inside `Eval` we need to sequence it with `flatMap`. So we won’t even be able to construct our `Eval` computation, since it would require circularly dependent `flatMap` calls on `Eval`. The `flatMap` calls themselves are eager and there’s no way to avoid that.

So, depending on whether we wrap the `result` into `Eval.defer`, we get either an infinite loop or a stack overflow for programs that involve `flatMap`-ing `ReverseState`.

Seems like we reached the limits of laziness in Scala here.

Source: The Haskell Cast #1 with Edward Kmett

## Update

There’s one case, though, where `flatMap` for `ReverseState` will work properly in Scala. It’s when your state type `S` is a lazy data structure (a standard `Stream`, for example).

It may seem like some random exceptional fact, but actually it’s the same case of providing the runtime with a condition to stop evaluation and break the circular computational dependency. This time it’s just less explicit and takes the form of `Stream`’s laziness.

Thanks to Oleg Nizhnik (@odomontois) for pointing me in this direction.

## Conclusion

In this post we found out that `ReverseState` is not a `Monad` in Scala. Again, I would really love to be proven wrong here, so if you happen to find a working instance, please ping me!

It’s not a `Monad`, but it’s an `Applicative`, which means we can still use it in some meaningful computations 🙂

As an example of such, we looked at right-to-left stateful traversals. Big thanks to Zhouyu Qian from Capital Match for his post about `ReverseState` in Haskell, which served as the foundation for the post you just read.

Thanks for reading!

Originally published on pavkin.ru