Do it in Java 8: The State Monad
In a previous article (Do it in Java 8: Automatic memoization), I wrote about memoization and said that memoization is about handling state between function calls, although the value returned by the function does not change from one call to another as long as the argument is the same. I showed how this could be done automatically. There are, however, other cases where handling state between function calls is necessary but cannot be done this simple way. One is handling state between recursive calls. In such a situation, the function is called only once from the user's point of view, but it is in fact called several times since it calls itself recursively. Most of these functions will not benefit internally from memoization. For example, the factorial function may be implemented recursively, and for one call to f(n), there will be n calls to the function, but never twice with the same value. So this function would not benefit from internal memoization. On the contrary, the Fibonacci function, if implemented according to its recursive definition, will call itself recursively a huge number of times with the same arguments.
Here is the standard definition of the Fibonacci function:
f(0) = 0, f(1) = 1, and f(n) = f(n - 1) + f(n - 2) for n > 1
This definition has a major problem: computing f(n) requires an exponential number of calls to the function (2·f(n + 1) - 1 calls in total), although only the n + 1 distinct values f(0) to f(n) are actually needed. The result is that, in practice, it is impossible to use this definition for n > 50 because the execution time grows exponentially. Since far more calls are made than there are distinct values, many values are obviously computed several times. So this definition is a good candidate for internal memoization.
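To make that growth concrete, here is a small instrumented sketch (the class NaiveFibCalls and its counter are illustrative additions, not part of the article's code). The number of invocations satisfies calls(n) = 1 + calls(n - 1) + calls(n - 2) with calls(0) = calls(1) = 1, which works out to 2·f(n + 1) - 1: 177 calls for n = 10, 21,891 for n = 20 and 2,692,537 for n = 30, while only n + 1 distinct arguments are ever used.

```java
import java.math.BigInteger;

// Hypothetical helper: counts how many times the naive definition
// f(n) = f(n - 1) + f(n - 2) invokes itself.
final class NaiveFibCalls {
    private static long calls = 0;

    static BigInteger fib(int n) {
        calls++; // one increment per invocation, base cases included
        return n < 2 ? BigInteger.valueOf(n) : fib(n - 1).add(fib(n - 2));
    }

    // Resets the counter, computes fib(n) and returns the number of invocations.
    static long countCalls(int n) {
        calls = 0;
        fib(n);
        return calls;
    }
}
```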
Warning: Java does not perform tail call elimination, so any deeply recursive method will eventually blow the stack, unless we implement TCO (Tail Call Optimization) ourselves as described in my previous article (Do it in Java 8: recursive and corecursive Fibonacci). However, TCO can't be applied to the original Fibonacci definition because it is not tail recursive. So we will write an example that works for n limited to a few thousand.
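For contrast, here is a minimal sketch of what a tail-recursive Fibonacci looks like (TailFib, fibAcc and its two accumulator parameters are illustrative names, not taken from the article). The recursive call is the very last operation, carrying the pair (f(k), f(k + 1)) along, whereas in the original definition the addition still has to happen after both recursive calls return. Java nevertheless pushes one stack frame per call, so even this form would need the TCO technique mentioned above for large n.

```java
import java.math.BigInteger;

final class TailFib {
    // Tail-recursive form: nothing remains to be done after the recursive call.
    // a and b hold two consecutive Fibonacci numbers.
    static BigInteger fibAcc(int n, BigInteger a, BigInteger b) {
        return n == 0 ? a : fibAcc(n - 1, b, a.add(b));
    }

    static BigInteger fib(int n) {
        return fibAcc(n, BigInteger.ZERO, BigInteger.ONE);
    }
}
```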
The naive implementation
Here is how we might implement the original definition:
static BigInteger fib(BigInteger n) {
    return n.equals(BigInteger.ZERO) || n.equals(BigInteger.ONE)
            ? n
            : fib(n.subtract(BigInteger.ONE)).add(fib(n.subtract(BigInteger.ONE.add(BigInteger.ONE))));
}
Do not try this implementation for values greater than 50. On my machine, fib(50) takes 44 minutes to return!
Basic memoized version
To avoid computing the same values several times, we can use a map where computed values are stored. This way, for each value, we first look into the map to see if it has already been computed. If it is present, we just retrieve it. Otherwise, we compute it, store it into the map and return it.
For this, we will use a special class called Memo. This is a normal HashMap that mimics the interface of a functional map, which means that inserting a new (key, value) pair returns a map, and looking up a key returns an Optional:
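import java.math.BigInteger;
import java.util.HashMap;
import java.util.Optional;

public class Memo extends HashMap<BigInteger, BigInteger> {

    // Functional-style lookup: a missing key gives Optional.empty()
    public Optional<BigInteger> retrieve(BigInteger key) {
        return Optional.ofNullable(super.get(key));
    }

    // Functional-style insertion: returns the map so that calls can be chained
    public Memo addEntry(BigInteger key, BigInteger value) {
        super.put(key, value);
        return this;
    }
}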
Note that this is not a true functional (immutable and persistent) map, but this is not a problem since it will not be shared.
We will also need a Tuple class that we define as:
public class Tuple<A, B> {
    public final A _1;
    public final B _2;

    public Tuple(A a, B b) {
        _1 = a;
        _2 = b;
    }
}
With this class, we can write the following basic implementation:
static BigInteger fibMemo1(BigInteger n) {
    return fibMemo(n, new Memo().addEntry(BigInteger.ZERO, BigInteger.ZERO)
                                .addEntry(BigInteger.ONE, BigInteger.ONE))._1;
}

static Tuple<BigInteger, Memo> fibMemo(BigInteger n, Memo memo) {
    return memo.retrieve(n)
            .map(x -> new Tuple<>(x, memo))
            .orElseGet(() -> {
                BigInteger x = fibMemo(n.subtract(BigInteger.ONE), memo)._1
                        .add(fibMemo(n.subtract(BigInteger.ONE).subtract(BigInteger.ONE), memo)._1);
                return new Tuple<>(x, memo.addEntry(n, x));
            });
}
This implementation works fine, provided values of n are not too big (each recursive call still uses a stack frame). For f(50), it returns in less than 1 millisecond, compared to the 44 minutes of the naive version. Note that we do not have to test for the terminal values (f(0) and f(1)): these values are simply inserted into the map at the start.
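To check that each value is now computed only once, here is a hedged sketch that follows the same look-up-then-compute logic as fibMemo and counts how often the compute branch runs. The class MemoFibCount, its computations counter and the use of a plain HashMap with int keys are illustrative simplifications, not the article's code. With f(0) and f(1) preloaded, computing f(n) for n >= 1 runs that branch exactly n - 1 times.

```java
import java.math.BigInteger;
import java.util.HashMap;
import java.util.Map;

final class MemoFibCount {
    static int computations = 0;

    // Same shape as fibMemo: look the value up first; otherwise compute it,
    // store it in the map and return it.
    static BigInteger fibMemo(int n, Map<Integer, BigInteger> memo) {
        BigInteger cached = memo.get(n);
        if (cached != null) {
            return cached;
        }
        computations++;
        BigInteger x = fibMemo(n - 1, memo).add(fibMemo(n - 2, memo));
        memo.put(n, x);
        return x;
    }

    static BigInteger fib(int n) {
        computations = 0;
        Map<Integer, BigInteger> memo = new HashMap<>();
        memo.put(0, BigInteger.ZERO); // terminal values are preloaded,
        memo.put(1, BigInteger.ONE);  // so no explicit base-case test is needed
        return fibMemo(n, memo);
    }
}
```

For n = 50, the compute branch runs 49 times, instead of the billions of calls made by the naive version.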
So, can we make this better? The main problem is that we have to handle the passing of the Memo map by hand. The signature of the fibMemo method is no longer fibMemo(BigInteger n), but fibMemo(BigInteger n, Memo memo). Could we simplify this? We might think about using automatic memoization as described in my previous article (Do it in Java 8: Automatic memoization). However, this will not work:
static Function<BigInteger, BigInteger> fib = new Function<BigInteger, BigInteger>() {
    @Override
    public BigInteger apply(BigInteger n) {
        return n.equals(BigInteger.ZERO) || n.equals(BigInteger.ONE)
                ? n
                : this.apply(n.subtract(BigInteger.ONE)).add(this.apply(n.subtract(BigInteger.ONE.add(BigInteger.ONE))));
    }
};

static Function<BigInteger, BigInteger> fibm = Memoizer.memoize(fib);
Besides the fact that we cannot use a lambda here, because a lambda cannot refer to itself through this (which we work around by using the original anonymous class syntax), the recursive calls are made to the non-memoized function, so it does not make things better: only the outermost call goes through the cache.
Using the State monad
In this example, each computed value is strictly evaluated by the fibMemo method, and this is what makes the memo parameter necessary. Instead of a method returning a value, what we need is a method returning a function that can be evaluated later. This function would take a Memo as its parameter, and this Memo instance would be necessary only at evaluation time. This is what the State monad does.
Java 8 does not provide the state monad, so we have to create it, but it is very simple. However, we first need an implementation of a list that is more functional than what Java offers. In a real case, we would use a true immutable and persistent List. The one I have written is about 1 000 lines, so I can't show it here. Instead, we will use a dummy functional list, backed by a java.util.ArrayList. Although this is less elegant, it does the same job:
import java.util.ArrayList;
import java.util.Optional;
import java.util.function.Function;

public class List<T> {

    // Mutable backing list; never exposed, so the public interface stays functional
    private java.util.List<T> list = new ArrayList<>();

    public static <T> List<T> empty() {
        return new List<T>();
    }

    @SafeVarargs
    public static <T> List<T> apply(T... ta) {
        List<T> result = new List<>();
        for (T t : ta) result.list.add(t);
        return result;
    }

    // Returns a new list with t in front; this list is left unchanged
    public List<T> cons(T t) {
        List<T> result = new List<>();
        result.list.add(t);
        result.list.addAll(list);
        return result;
    }

    public <U> U foldRight(U seed, Function<T, Function<U, U>> f) {
        U result = seed;
        for (int i = list.size() - 1; i >= 0; i--) {
            result = f.apply(list.get(i)).apply(result);
        }
        return result;
    }

    public <U> List<U> map(Function<T, U> f) {
        List<U> result = new List<>();
        for (T t : list) {
            result.list.add(f.apply(t));
        }
        return result;
    }

    public List<T> filter(Function<T, Boolean> f) {
        List<T> result = new List<>();
        for (T t : list) {
            if (f.apply(t)) {
                result.list.add(t);
            }
        }
        return result;
    }

    public Optional<T> findFirst() {
        return list.size() == 0 ? Optional.empty() : Optional.of(list.get(0));
    }

    @Override
    public String toString() {
        StringBuilder s = new StringBuilder("[");
        for (T t : list) {
            s.append(t).append(", ");
        }
        return s.append("NIL]").toString();
    }
}
The implementation is not functional, but the interface is! And although there are lots of missing capabilities, we have all we need.
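For comparison, the core of a truly immutable and persistent list is small; what takes many lines is the set of operations around it. Here is a hedged minimal sketch (ConsList and its methods are hypothetical names, unrelated to the author's own implementation). The key point is that cons does not copy anything: the new cell simply points at the existing list, which is shared, so both versions remain valid.

```java
import java.util.function.Function;

// Hypothetical minimal persistent list.
abstract class ConsList<T> {
    abstract boolean isEmpty();
    abstract T head();
    abstract ConsList<T> tail();

    static <T> ConsList<T> empty() { return new Nil<T>(); }

    // O(1): the new cell points at this list, which is shared rather than copied.
    ConsList<T> cons(T t) { return new Cons<T>(t, this); }

    // f(head)(foldRight(tail)); recursive, so not stack safe for very long lists.
    <U> U foldRight(U seed, Function<T, Function<U, U>> f) {
        return isEmpty() ? seed : f.apply(head()).apply(tail().foldRight(seed, f));
    }

    private static final class Nil<T> extends ConsList<T> {
        boolean isEmpty() { return true; }
        T head() { throw new IllegalStateException("head of empty list"); }
        ConsList<T> tail() { throw new IllegalStateException("tail of empty list"); }
    }

    private static final class Cons<T> extends ConsList<T> {
        private final T head;
        private final ConsList<T> tail;

        Cons(T head, ConsList<T> tail) {
            this.head = head;
            this.tail = tail;
        }

        boolean isEmpty() { return false; }
        T head() { return head; }
        ConsList<T> tail() { return tail; }
    }
}
```

For example, if l1 is [1, 2, 3] and l2 = l1.cons(0), then l2.tail() is the very same object as l1.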
Now, we can write the state monad. It is often simply called State, but I prefer to call it StateMonad in order to avoid confusion between the state and the monad:
import java.util.function.Function;

public class StateMonad<S, A> {

    // The wrapped computation: from a state, produce a value and a new state
    public final Function<S, StateTuple<A, S>> runState;

    public StateMonad(Function<S, StateTuple<A, S>> runState) {
        this.runState = runState;
    }

    public static <S, A> StateMonad<S, A> unit(A a) {
        return new StateMonad<>(s -> new StateTuple<>(a, s));
    }

    public static <S> StateMonad<S, S> get() {
        return new StateMonad<>(s -> new StateTuple<>(s, s));
    }

    public static <S, A> StateMonad<S, A> getState(Function<S, A> f) {
        return new StateMonad<>(s -> new StateTuple<>(f.apply(s), s));
    }

    public static <S> StateMonad<S, Nothing> transition(Function<S, S> f) {
        return new StateMonad<>(s -> new StateTuple<>(Nothing.instance, f.apply(s)));
    }

    public static <S, A> StateMonad<S, A> transition(Function<S, S> f, A value) {
        return new StateMonad<>(s -> new StateTuple<>(value, f.apply(s)));
    }

    public static <S, A> StateMonad<S, List<A>> compose(List<StateMonad<S, A>> fs) {
        return fs.foldRight(StateMonad.unit(List.<A>empty()),
                f -> acc -> f.map2(acc, a -> b -> b.cons(a)));
    }

    // Runs this computation, then the one returned by f, threading the state
    public <B> StateMonad<S, B> flatMap(Function<A, StateMonad<S, B>> f) {
        return new StateMonad<>(s -> {
            StateTuple<A, S> temp = runState.apply(s);
            return f.apply(temp.value).runState.apply(temp.state);
        });
    }

    public <B> StateMonad<S, B> map(Function<A, B> f) {
        return flatMap(a -> StateMonad.unit(f.apply(a)));
    }

    public <B, C> StateMonad<S, C> map2(StateMonad<S, B> sb, Function<A, Function<B, C>> f) {
        return flatMap(a -> sb.map(b -> f.apply(a).apply(b)));
    }

    // Runs the computation from an initial state and keeps only the value
    public A eval(S s) {
        return runState.apply(s).value;
    }
}
This class is parameterized by two types: the value type A and the state type S. In our case, A will be BigInteger and S will be Memo.
This class holds a function from a state to a (value, state) tuple. This function is held in the runState field. This is similar to the value held by the Optional monad.
To make it a monad, this class needs a unit method and a flatMap method. The unit method takes a value as its parameter and returns a StateMonad. It could be implemented as a constructor. Here, it is a factory method.
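The flatMap method takes a function from A (a value) to StateMonad<S, B> and returns a new StateMonad<S, B>. (In our case, A is the same as B.) When this new StateMonad is run with a state, it first runs the original one, passes the resulting value to the function, and then runs the StateMonad returned by the function with the resulting state. This is how the state is threaded from one computation to the next without being passed by hand.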
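All other methods are convenience methods:
- map allows binding a function from A to B instead of a function from A to StateMonad<S, B>. It is implemented in terms of flatMap and unit.
- eval allows easy retrieval of the value held by the StateMonad, by running it from an initial state.
- getState allows creating a StateMonad from a function S -> A.
- transition takes a function from state to state and a value, and returns a new StateMonad holding the value and the state resulting from the application of the function. In other words, it allows changing the state without computing a new value from it.
- get returns a StateMonad whose value is the current state, map2 combines the values of two StateMonads with a curried function, and compose turns a List<StateMonad<S, A>> into a StateMonad<S, List<A>>. These are used by the state machine later in this article.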
There is also another transition method taking only a function and returning a StateMonad<S, Nothing>. Nothing is a special class:
public final class Nothing {
    public static final Nothing instance = new Nothing();

    private Nothing() {}
}
This class could be replaced by Void, to mean that we do not care about the type. However, Void is not supposed to be instantiated, and the only reference of type Void is normally null. The problem is that null does not carry its type. We could instantiate a Void instance through reflection:
Constructor<Void> constructor;
constructor = Void.class.getDeclaredConstructor();
constructor.setAccessible(true);
Void nothing = constructor.newInstance();
but this is really ugly, so we create a Nothing type with a single instance. This does the trick, although to be complete, Nothing should be able to replace any type (like null), which does not seem to be possible in Java.
Using the StateMonad class, we can rewrite our program:
static BigInteger fibMemo2(BigInteger n) {
    return fibMemo(n).eval(new Memo().addEntry(BigInteger.ZERO, BigInteger.ZERO)
                                     .addEntry(BigInteger.ONE, BigInteger.ONE));
}

static StateMonad<Memo, BigInteger> fibMemo(BigInteger n) {
    return StateMonad.getState((Memo m) -> m.retrieve(n))
            .flatMap(u -> u.map(StateMonad::<Memo, BigInteger> unit)
                    .orElse(fibMemo(n.subtract(BigInteger.ONE))
                            .flatMap(x -> fibMemo(n.subtract(BigInteger.ONE).subtract(BigInteger.ONE))
                                    .map(x::add)
                                    .flatMap(z -> StateMonad.transition((Memo m) -> m.addEntry(n, z), z)))));
}
Now, the fibMemo method only takes a BigInteger as its parameter and returns a StateMonad, which means that when this method returns, nothing has been evaluated yet. The Memo doesn't even exist!
To get the result, we may call the eval method, passing it the Memo instance.
If you find this code difficult to understand, here is an exploded, commented version using longer identifiers:
// Needs java.util.Optional, java.util.function.Function and java.util.function.Supplier
static StateMonad<Memo, BigInteger> fibMemo(BigInteger n) {
    /*
     * Create a function of type Memo -> Optional<BigInteger> with a closure
     * over the n parameter.
     */
    Function<Memo, Optional<BigInteger>> retrieveValueFromMapIfPresent =
            (Memo memoizationMap) -> memoizationMap.retrieve(n);
    /*
     * Create a state from this function.
     */
    StateMonad<Memo, Optional<BigInteger>> initialState =
            StateMonad.getState(retrieveValueFromMapIfPresent);
    /*
     * Create a function for converting the value (BigInteger) into a StateMonad
     * instance. This function will be bound to the Optional resulting from the
     * lookup into the map to give the result if the value was found.
     */
    Function<BigInteger, StateMonad<Memo, BigInteger>> createStateFromValue =
            StateMonad::<Memo, BigInteger> unit;
    /*
     * The value computation proper. This can't be easily decomposed because it
     * makes heavy use of closures. It first calls fibMemo(n - 1) recursively,
     * producing a StateMonad<Memo, BigInteger>. It then flatMaps it to a new
     * recursive call to fibMemo(n - 2) (actually fibMemo(n - 1 - 1)), getting a
     * new StateMonad<Memo, BigInteger> that is mapped with BigInteger addition
     * of the preceding value (x). Finally, it flatMaps the result with the
     * function z -> StateMonad.transition((Memo m) -> m.addEntry(n, z), z),
     * which returns a new StateMonad holding the sum z, with z added to the map.
     * The computation is wrapped in a Supplier so that it is only built when the
     * value is missing: building it eagerly here would call fibMemo(n - 1),
     * which would call fibMemo(n - 2), and so on without end.
     */
    Supplier<StateMonad<Memo, BigInteger>> computedValue = () ->
            fibMemo(n.subtract(BigInteger.ONE))
                    .flatMap(x -> fibMemo(n.subtract(BigInteger.ONE).subtract(BigInteger.ONE))
                            .map(x::add)
                            .flatMap(z -> StateMonad.transition((Memo m) -> m.addEntry(n, z), z)));
    /*
     * Create a function taking an Optional<BigInteger> as its parameter and
     * returning a state. This is the main function: it returns the value in
     * the Optional if it is present; otherwise, it computes the value and puts
     * it into the map before returning it.
     */
    Function<Optional<BigInteger>, StateMonad<Memo, BigInteger>> computeFiboValueIfAbsentFromMap =
            u -> u.map(createStateFromValue).orElseGet(computedValue);
    /*
     * Bind the computeFiboValueIfAbsentFromMap function to the initial state
     * and return the result.
     */
    return initialState.flatMap(computeFiboValueIfAbsentFromMap);
}
The most important part is the following:
fibMemo(n.subtract(BigInteger.ONE))
        .flatMap(x -> fibMemo(n.subtract(BigInteger.ONE).subtract(BigInteger.ONE))
                .map(x::add)
                .flatMap(z -> StateMonad.transition((Memo m) -> m.addEntry(n, z), z)))
This kind of chained flatMap/map code is central to functional programming, although some other languages (Scala, for example) let you write it as a "for comprehension" instead. As Java 8 does not have for comprehensions, we have to use this form.
At this point, we have seen that using the state monad allows abstracting the handling of state. This technique can be used every time you have to handle state.
More uses of the state monad
The state monad may be used in many other cases where state must be maintained in a functional way. Most programs based upon maintaining state use a concept known as a State Machine. A state machine is defined by an initial state and a series of inputs. Each input submitted to the state machine produces a new state by applying one of several possible transitions, chosen according to a list of conditions on both the input and the current state.
If we take the example of a bank account, the initial state would be the initial balance of the account. Possible transitions would be deposit(amount) and withdraw(amount). The conditions would be true for deposit and balance >= amount for withdraw.
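Before looking at the generic version, the bank account machine can be sketched in a few lines. This is a hedged illustration, not the article's design: the signed-integer encoding of inputs (positive for a deposit, negative for a withdrawal) and the BalanceMachine class are assumptions made for brevity. Each step picks the first transition whose condition holds, and running the machine is simply a left fold of that step over the inputs, starting from the initial state.

```java
import java.util.List;

// Hypothetical sketch: the state is the balance; an input is a signed amount
// (positive = deposit, negative = withdrawal).
final class BalanceMachine {
    // One step: apply the first transition whose condition holds.
    static int step(int balance, int input) {
        if (input >= 0) {                 // deposit: always allowed
            return balance + input;
        } else if (balance >= -input) {   // withdraw: balance >= amount
            return balance + input;
        } else {                          // refused withdrawal: state unchanged
            return balance;
        }
    }

    // Running the machine folds step over the inputs from the initial state.
    static int run(int initial, List<Integer> inputs) {
        int state = initial;
        for (int input : inputs) {
            state = step(state, input);
        }
        return state;
    }
}
```

With inputs 100, -50, -150, 200 and -150 from an initial balance of 0, the third input is refused and the final balance is 100.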
Given the state monad that we have implemented above, we could write a parameterized state machine:
import java.util.Optional;
import java.util.function.Function;

public class StateMachine<I, S> {

    // Maps each input to a StateMonad applying the first matching transition
    Function<I, StateMonad<S, Nothing>> function;

    public StateMachine(List<Tuple<Condition<I, S>, Transition<I, S>>> transitions) {
        function = i -> StateMonad.transition(m ->
                Optional.of(new StateTuple<>(i, m))
                        .flatMap((StateTuple<I, S> t) -> transitions
                                .filter((Tuple<Condition<I, S>, Transition<I, S>> x) -> x._1.test(t))
                                .findFirst()
                                .map((Tuple<Condition<I, S>, Transition<I, S>> y) -> y._2.apply(t)))
                        .get()); // throws NoSuchElementException if no condition matches
    }

    // Chains the transitions for all inputs, then exposes the final state as the value
    public StateMonad<S, S> process(List<I> inputs) {
        List<StateMonad<S, Nothing>> a = inputs.map(function);
        StateMonad<S, List<Nothing>> b = StateMonad.compose(a);
        return b.flatMap(x -> StateMonad.get());
    }
}
This machine uses a bunch of helper classes. First, the inputs are represented by an interface:
public interface Input { boolean isDeposit(); boolean isWithdraw(); int getAmount(); }
There are two implementations of Input:
public class Deposit implements Input {
    private final int amount;

    public Deposit(int amount) {
        super();
        this.amount = amount;
    }

    @Override public boolean isDeposit() { return true; }
    @Override public boolean isWithdraw() { return false; }
    @Override public int getAmount() { return this.amount; }
}

public class Withdraw implements Input {
    private final int amount;

    public Withdraw(int amount) {
        super();
        this.amount = amount;
    }

    @Override public boolean isDeposit() { return false; }
    @Override public boolean isWithdraw() { return true; }
    @Override public int getAmount() { return this.amount; }
}
Then come two functional interfaces for conditions and transitions:
import java.util.function.Function;
import java.util.function.Predicate;

public interface Condition<I, S> extends Predicate<StateTuple<I, S>> {}

public interface Transition<I, S> extends Function<StateTuple<I, S>, S> {}
These act as type aliases in order to simplify the code. We could have used the predicate and the function directly. In the same manner, we use a StateTuple class instead of a normal tuple:
import java.util.Objects;

public class StateTuple<A, S> {
    public final A value;
    public final S state;

    // Neither the value nor the state may be null
    public StateTuple(A a, S s) {
        value = Objects.requireNonNull(a);
        state = Objects.requireNonNull(s);
    }
}
This is exactly the same as an ordinary tuple, with named members instead of numbered ones. Numbered members allow using the same class everywhere, but a specific class like this one makes the code easier to read, as we will see.
The last utility class is Outcome, which represents the result returned by the state machine:
public class Outcome {
    public final Integer account;
    public final List<Either<Exception, Integer>> operations;

    public Outcome(Integer account, List<Either<Exception, Integer>> operations) {
        super();
        this.account = account;
        this.operations = operations;
    }

    public String toString() {
        return "(" + account.toString() + "," + operations.toString() + ")";
    }
}
This again could be replaced with a Tuple<Integer, List<Either<Exception, Integer>>>, but using named fields makes the code easier to read. (In some functional languages, we could use type aliases for this.)
Here, we use an Either class, which is another kind of monad that Java does not offer. I will not show the complete class, but only the parts that are useful for this example:
public interface Either<A, B> {

    boolean isLeft();
    boolean isRight();
    A getLeft();
    B getRight();

    static <A, B> Either<A, B> right(B value) {
        return new Right<>(value);
    }

    static <A, B> Either<A, B> left(A value) {
        return new Left<>(value);
    }

    public class Left<A, B> implements Either<A, B> {
        private final A left;

        private Left(A left) {
            super();
            this.left = left;
        }

        @Override public boolean isLeft() { return true; }
        @Override public boolean isRight() { return false; }
        @Override public A getLeft() { return this.left; }
        @Override public B getRight() { throw new IllegalStateException("getRight() called on Left value"); }
        @Override public String toString() { return left.toString(); }
    }

    public class Right<A, B> implements Either<A, B> {
        private final B right;

        private Right(B right) {
            super();
            this.right = right;
        }

        @Override public boolean isLeft() { return false; }
        @Override public boolean isRight() { return true; }
        @Override public A getLeft() { throw new IllegalStateException("getLeft() called on Right value"); }
        @Override public B getRight() { return this.right; }
        @Override public String toString() { return right.toString(); }
    }
}
This implementation is missing a flatMap method, but we will not need it. The Either class is somewhat like the Optional Java class, in that it may be used to represent the result of an evaluation that may return a value or something else, like an exception, an error message or whatever. What is important is that it can hold one of two things of different types.
We now have all we need to use our state machine:
public class Account {

    public static StateMachine<Input, Outcome> createMachine() {
        // Deposits are always accepted
        Condition<Input, Outcome> predicate1 = t -> t.value.isDeposit();
        Transition<Input, Outcome> transition1 = t -> new Outcome(
                t.state.account + t.value.getAmount(),
                t.state.operations.cons(Either.right(t.value.getAmount())));
        // Withdrawals are accepted only if the balance covers them
        Condition<Input, Outcome> predicate2 = t -> t.value.isWithdraw() && t.state.account >= t.value.getAmount();
        Transition<Input, Outcome> transition2 = t -> new Outcome(
                t.state.account - t.value.getAmount(),
                t.state.operations.cons(Either.right(- t.value.getAmount())));
        // Catch-all: anything not matched above is recorded as an error
        Condition<Input, Outcome> predicate3 = t -> true;
        Transition<Input, Outcome> transition3 = t -> new Outcome(
                t.state.account,
                t.state.operations.cons(Either.left(new IllegalStateException(
                        String.format("Can't withdraw %s because balance is only %s",
                                t.value.getAmount(), t.state.account)))));
        List<Tuple<Condition<Input, Outcome>, Transition<Input, Outcome>>> transitions = List.apply(
                new Tuple<>(predicate1, transition1),
                new Tuple<>(predicate2, transition2),
                new Tuple<>(predicate3, transition3));
        return new StateMachine<>(transitions);
    }
}
This could not be simpler. We just define each possible condition and the corresponding transition, and then build a list of (Condition, Transition) tuples that is used to instantiate the state machine. There are, however, two rules that must be enforced:
- Conditions must be put in the right order, with the more specific first and the more general last.
- We must be careful to match all possible cases. Otherwise, we will get an exception.
At this stage, nothing has been evaluated. We did not even use the initial state!
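The two rules above follow directly from how StateMachine picks a transition: filter(...).findFirst() keeps the first matching tuple, and the final get() fails on an empty Optional. Here is a hedged sketch isolating that selection logic; RuleOrder and its Rule class are illustrative names, with a String standing in for the transition. A general rule placed first shadows the specific one, and an input matched by no rule makes Optional.get() throw NoSuchElementException.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;

final class RuleOrder {
    // Hypothetical (condition, name) pair standing in for (Condition, Transition).
    static final class Rule {
        final Predicate<Integer> condition;
        final String name;

        Rule(Predicate<Integer> condition, String name) {
            this.condition = condition;
            this.name = name;
        }
    }

    // First matching rule wins, as with filter(...).findFirst() in StateMachine;
    // Optional.get() throws NoSuchElementException when no rule matches.
    static String firstMatch(List<Rule> rules, int input) {
        return rules.stream()
                .filter(r -> r.condition.test(input))
                .findFirst()
                .map(r -> r.name)
                .get();
    }

    static final Rule positive = new Rule(i -> i > 0, "specific");
    static final Rule always = new Rule(i -> true, "general");

    static final List<Rule> specificFirst = Arrays.asList(positive, always);
    static final List<Rule> generalFirst = Arrays.asList(always, positive);
    static final List<Rule> incomplete = Arrays.asList(positive);
}
```

With generalFirst, the input 5 is handled by the general rule, which is exactly the mistake the first rule warns against; with incomplete, the input -1 throws.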
To run the state machine, we must create a list of inputs and feed it to the machine, for example:
List<Input> inputs = List.apply(
        new Deposit(100),
        new Withdraw(50),
        new Withdraw(150),
        new Deposit(200),
        new Withdraw(150));
StateMonad<Outcome, Outcome> state = Account.createMachine().process(inputs);
Again, nothing has been evaluated yet. To get the result, we just evaluate the StateMonad, using an initial state:
Outcome outcome = state.eval(new Outcome(0, List.empty()));
If we run the program with the list above and call toString() on the resulting outcome (we can't do anything more useful, since the Either class is so minimal!), we get the following result:
(100,[-150, 200, java.lang.IllegalStateException: Can't withdraw 150 because balance is only 50, -50, 100, NIL])
This is a tuple of the resulting balance for the account (100) and the list of operations that have been carried out, most recent first (since cons adds each new operation at the front of the list). We can see that successful operations are represented by a signed integer, and failed operations by an error message.
This, of course, is a very minimal example, and as usual, one may think it would be much easier to do it the imperative way. However, think of a more complex example, like a text parser. All there is to do to adapt the state machine is to define the state representation (the Outcome class), define the possible inputs and create the list of (Condition, Transition) tuples. Going the functional way does not make the whole thing simpler. However, it allows abstracting the implementation of the state machine from the requirements. The only thing we have to do to create a new state machine is to write the new requirements!