Refactoring to Functional Patterns in Scala

Let's take a look at some approaches to refactoring functional code. Here, we use Scala and the State and Writer monads to see how it's done.

If you have a background in Java like me, you've probably read Refactoring To Patterns. It's a very cool book that shows you how to refactor object-oriented code step by step until you eventually reach full-blown Gang of Four design patterns. It had a huge impact on me at the time. It left me with the feeling that code is alive and wants to be rearranged this way, and that patterns emerge naturally.

Fast-forward 10 years, and I work at a very cool startup (Bigpanda) where we use Scala and functional programming for our back-end services. Working with FP is very different and, in my opinion, far easier. There are no more complicated class hierarchies and no more complicated patterns, only functions, functions, functions. If GoF design patterns are no longer our destination, then the refactoring steps we must take are very different.

In this post, I'll introduce you to some approaches to refactoring functional code. I will build from simple refactorings to more advanced ones using the State and Writer monads, which serve as functional design patterns.

Make Sure You Have a Full Suite of Tests With Good Coverage

Refactoring without tests is like jumping without a safety net. You can use sbt as a very useful continuous test runner:

~testQuick io.company.service.PricingServiceSpec

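The ~ prefix makes sbt watch your sources. Each time you save a file, sbt recompiles and reruns only the tests that need it: tests that failed in the previous run, tests that have not been run yet, and tests whose dependencies were recompiled.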

Eliminate Primitive Types With Value Classes


def buy(lastPrice: String, name: String, exchange: String): (Double, Long) = ???


case class Symbol(name: String) extends AnyVal
case class Exchange(name: String) extends AnyVal
case class Price(value: Double) extends AnyVal
case class Timestamp(ts: Long) extends AnyVal

def buy2(lastPrice: Price, symbol: Symbol, exchange: Exchange): (Price, Timestamp) = ???

We have a package called types, and we will put all our value classes in a values.scala file. We will also add Ordering implicits there.

implicit val timestampOrdering: Ordering[Timestamp] = Ordering.by(_.ts)
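Here is a small, self-contained sketch of what such a values.scala might contain, assuming prices wrap a Double. The object name Values and the latest helper are hypothetical, added only to show the Ordering instance being picked up implicitly.

```scala
// Sketch of a values.scala: value classes plus the Ordering instances that go with them.
// Wrapping Price around a Double and the `latest` helper are illustrative assumptions.
object Values {
  final case class Price(value: Double) extends AnyVal
  final case class Timestamp(ts: Long) extends AnyVal

  // Timestamps are ordered by their underlying epoch value.
  implicit val timestampOrdering: Ordering[Timestamp] = Ordering.by[Timestamp, Long](_.ts)

  // Hypothetical helper: the price of the most recent tick, if any.
  // sortBy(_._1) uses timestampOrdering because it is in scope here.
  def latest(ticks: List[(Timestamp, Price)]): Option[Price] =
    ticks.sortBy(_._1).lastOption.map(_._2)
}
```

Because the compiler now tells a Timestamp apart from a Price, swapping two arguments becomes a compile error instead of a silent bug.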

Rewrite on the Side and Then Switch the Functions

Typically, I do not start by deleting old code. First, I write the new function on the side and make sure it compiles; then I switch the callers over from the old function and make sure the tests pass. This handy trick lets you backtrack quickly if you have an error somewhere.
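A minimal sketch of the trick, using a simplified, hypothetical version of buy (the fixed spread and the two-argument signature are assumptions, not the article's code). The old primitive-typed entry point stays and delegates to the new typed function, so existing tests exercise the new code until every caller has switched.

```scala
// Sketch of "rewrite on the side": buy2 is the new, strongly typed version;
// buy keeps its old signature and delegates, so it can be deleted later.
// The spread and the pricing rule are hypothetical.
object SideBySide {
  final case class Price(value: Double) extends AnyVal
  final case class Timestamp(ts: Long) extends AnyVal

  private val spread = 0.5 // hypothetical fixed spread added to the last price

  // New function, written next to the old one.
  def buy2(lastPrice: Price, now: Timestamp): (Price, Timestamp) =
    (Price(lastPrice.value + spread), now)

  // Old signature kept for existing callers; it now delegates to buy2.
  def buy(lastPrice: Double, now: Long): (Double, Long) = {
    val (p, t) = buy2(Price(lastPrice), Timestamp(now))
    (p.value, t.ts)
  }
}
```

If the tests go red after switching, reverting is a one-line change in buy.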

Align the Types Between Functions

If your functions compose together in a natural way, it means that you have found the right level of abstraction.

Keep them small and focused on one thing, and add type annotations for the return types to increase readability.

If you find that you need to work hard with type transformations to be able to compose your functions together, then try this:

  • Inline, inline, inline, and retry.

After a while, you get the hang of it, and your functions will be focused and will compose together. You can also do some upfront design.
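As a small illustration of aligned types: when one function's output type is the next one's input type, they compose with andThen and no conversion code is needed in between. PriceChange, Signal and the 5% thresholds below are hypothetical, chosen only for the example.

```scala
// Functions whose types line up compose directly with andThen.
// PriceChange, Signal and the 5% thresholds are hypothetical.
object Compose {
  final case class Price(value: Double) extends AnyVal
  final case class PriceChange(pct: Double) extends AnyVal

  sealed trait Signal
  case object Buy extends Signal
  case object Sell extends Signal
  case object Hold extends Signal

  // Price => PriceChange, once the reference price is fixed.
  def change(reference: Price): Price => PriceChange =
    next => PriceChange((next.value - reference.value) / reference.value * 100)

  // PriceChange => Signal
  def signal(c: PriceChange): Signal =
    if (c.pct > 5) Buy else if (c.pct < -5) Sell else Hold

  // The output of change is the input of signal, so they compose as-is.
  def decide(reference: Price): Price => Signal =
    change(reference).andThen(signal)
}
```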

Personally, I found A Type Driven Approach to Functional Design helpful. It's in Haskell, but it is still very relevant and will give you a sense of how to design functions that compose together.

Use State Monads for Functions That Need Previously Computed Values

Let's define some types to work with:

sealed abstract class CreditRating(val rating: Int)

case class Good(points: Int) extends CreditRating(points)
case class Fair(points: Int) extends CreditRating(points)
case class Poor(points: Int) extends CreditRating(points)

case class PriceEvent(symbol: Symbol, price: Price, ts: Timestamp)

In any meaningful service, you will need previously computed data. You'll also want to persist it in case you crash or restart your app. This leads to stateful functions.

In order to rate a stock, for example, we need the previous prices and rating. This usually leads to long, ugly parameter lists. Because our data structures are immutable, we have to return new, updated versions of them.


def rateStock(historicalPrices:  Map[Symbol, List[(Timestamp, Price)]],
             lastRatings: Map[Symbol, CreditRating],
             symbol: Symbol, 
             newPrice: Price): (Map[Symbol, CreditRating], List[(Timestamp, Price)]) = ???

Quite ugly!
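To see why this hurts, here is a standard-library-only sketch of threading such state by hand. The Ratings map, the rate function and its points rule are hypothetical stand-ins for the article's types; what matters is that every call returns a new state that the caller must remember to pass on.

```scala
object ManualThreading {
  // Hypothetical stand-in for the article's state: points per symbol.
  type Ratings = Map[String, Int]

  // Hypothetical rule: a price above 100 earns 10 points, anything else earns 1.
  def rate(ratings: Ratings, symbol: String, price: Double): (Ratings, Int) = {
    val points = ratings.getOrElse(symbol, 0) + (if (price > 100) 10 else 1)
    (ratings + (symbol -> points), points)
  }

  // Every step returns a fresh state that must be threaded into the next call by hand.
  def rateTwice(r0: Ratings): (Ratings, (Int, Int)) = {
    val (r1, first)  = rate(r0, "AAPL", 145.5)
    val (r2, second) = rate(r1, "AAPL", 150.0) // passing r0 here would silently drop the first update
    (r2, (first, second))
  }
}
```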

This is a very common pattern in FP. You can use a State monad to communicate to the reader that they are about to deal with stateful functions.

We use the State data type from Cats.

We encapsulate the moving parts in our own defined type:

case class StockState(historicalPrices: Map[Symbol, List[(Timestamp, Price)]],
                      lastRatings: Map[Symbol, CreditRating])

We use State to clean up the parameter list and the return type:

import cats.implicits._
import cats.data.State

def rateStatefulStock(symbol: Symbol, newPrice: Price): State[StockState, CreditRating] = ???

We can improve the type readability with a type alias.


type StateAction[A] = State[StockState, A]
def rateStatefulStock(symbol: Symbol, newPrice: Price): StateAction[CreditRating] = ???

The function is far cleaner and can compute and update the ratings from the previous state.

This gives us the ability to chain state functions as follows and be guaranteed that each function receives the correct, latest updated state. Very cool!

for {

    a <- rateStatefulStock(Symbol("AAPL"), Price(145.5))  

    // something magical happens here, 
    // it passes on the correct StockState to the next function

    s <- rateStatefulStock(Symbol("SAMSNG"), Price(2123.3))   

} yield (a, s)
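The "magic" in that comment is just flatMap. Below is a minimal, hand-rolled State written only with the standard library. It is not Cats' implementation (Cats defines State on top of IndexedStateT and Eval), and the Prices map and rating rule are hypothetical, but the chaining works the same way: each step receives the state produced by the previous one.

```scala
object HandRolledState {
  // A state transition: given the current state, return the next state and a value.
  final case class State[S, A](run: S => (S, A)) {
    def map[B](f: A => B): State[S, B] =
      State { s => val (s1, a) = run(s); (s1, f(a)) }

    // The "magic": run this step, then hand the state it produced to the next step.
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State { s => val (s1, a) = run(s); f(a).run(s1) }
  }

  // Hypothetical simplified state: last price seen per symbol.
  type Prices = Map[String, Double]

  // Hypothetical rule: "Good" if the price rose since last time, "Poor" otherwise, "Fair" if unseen.
  def rate(symbol: String, price: Double): State[Prices, String] =
    State { prices =>
      val rating = prices.get(symbol) match {
        case Some(last) if price > last => "Good"
        case Some(_)                    => "Poor"
        case None                       => "Fair"
      }
      (prices + (symbol -> price), rating)
    }

  // Desugars to rate(...).flatMap(a => rate(...).map(b => (a, b))).
  val program: State[Prices, (String, String)] =
    for {
      a <- rate("AAPL", 145.5)
      b <- rate("AAPL", 150.0)
    } yield (a, b)
}
```

Running program.run(Map.empty) yields the final state together with both ratings; the second call sees the price stored by the first without any manual plumbing.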

Use a Writer Monad to Track State Transitions When Using State

If you work with event sourcing, you will want to recreate your state from all the transitions you carried out. In order to keep track of state transitions without complicating your function, you can use a Writer monad to log all the transitions in a List.

First, let's define some more types:

sealed trait Transition
case class UpgradedRating(newRating: CreditRating) extends Transition
case class DowngradedRating(newRating: CreditRating) extends Transition

We want to use State and Writer together, so let's use WriterT to combine them:

import cats.data.WriterT 

type StateActionWithTransitions[A] = WriterT[StateAction, List[Transition], A]

Use this function to log transitions and add to the final transition list:

def archive(evts: List[Transition]): StateActionWithTransitions[Unit] =
    WriterT.tell[StateAction, List[Transition]](evts)

Boilerplate to wire up State and Writer together:

def lift[A](s: StateAction[A]): StateActionWithTransitions[A] =
    WriterT.liftF[StateAction, List[Transition], A](s)
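Conceptually, a WriterT over State is a state function that also returns a log, and flatMap concatenates the logs of consecutive steps. The sketch below hand-rolls that combination with the standard library to show what archive and lift do. It is an analogue, not Cats' code, and the Int ratings and the Upgraded/Downgraded names are simplifications.

```scala
object StateWithLog {
  type Ratings = Map[String, Int]

  sealed trait Transition
  final case class Upgraded(to: Int) extends Transition
  final case class Downgraded(to: Int) extends Transition

  // Hand-rolled analogue of WriterT[State[Ratings, *], List[Transition], A].
  final case class Step[A](run: Ratings => (Ratings, List[Transition], A)) {
    def map[B](f: A => B): Step[B] =
      Step { s => val (s1, log, a) = run(s); (s1, log, f(a)) }

    def flatMap[B](f: A => Step[B]): Step[B] =
      Step { s =>
        val (s1, log1, a) = run(s)
        val (s2, log2, b) = f(a).run(s1)
        (s2, log1 ++ log2, b) // logs accumulate in order
      }
  }

  // Analogue of archive: state untouched, entries appended to the log.
  def archive(entries: List[Transition]): Step[Unit] = Step(s => (s, entries, ()))

  // Analogue of lift: a plain state function that contributes an empty log.
  def lift[A](f: Ratings => (Ratings, A)): Step[A] =
    Step { s => val (s1, a) = f(s); (s1, Nil, a) }

  def getRating(symbol: String): Step[Int] = lift(s => (s, s.getOrElse(symbol, 0)))
  def setRating(symbol: String, r: Int): Step[Unit] = lift(s => (s + (symbol -> r), ()))

  // Same shape as the article's final rateStatefulStock, with the new rating passed in directly.
  def rate(symbol: String, newRating: Int): Step[Int] =
    for {
      old <- getRating(symbol)
      _   <- setRating(symbol, newRating)
      transition = if (newRating > old) Upgraded(newRating) else Downgraded(newRating)
      _   <- archive(transition :: Nil)
    } yield newRating
}
```

Chaining two calls to rate and running the result from an empty map returns the final ratings, both transitions in order, and the last rating.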

Pure functions have simple return types that are not wrapped in StateActionWithTransitions. This tells the reader that they neither read nor change the state.

def calculateRating(stock: Symbol, old: CreditRating, newPrice: Price): CreditRating = 
    if (stock.name == "AAPL") Good(1000) else if(newPrice.value == 0) Poor(0) else Fair(300) 

def calculateTransition(oldRating: CreditRating, newRating: CreditRating): Transition = 
    if (newRating.rating > oldRating.rating) UpgradedRating(newRating) else DowngradedRating(newRating)
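Because these two functions are pure, you can check them with plain assertions and no state setup. The sketch below repeats the article's definitions so that it is self-contained. It assumes Price wraps a Double, and it passes newRating to DowngradedRating so the transition records the rating that was actually set.

```scala
object PureRatings {
  final case class Symbol(name: String) extends AnyVal
  final case class Price(value: Double) extends AnyVal

  sealed abstract class CreditRating(val rating: Int)
  final case class Good(points: Int) extends CreditRating(points)
  final case class Fair(points: Int) extends CreditRating(points)
  final case class Poor(points: Int) extends CreditRating(points)

  sealed trait Transition
  final case class UpgradedRating(newRating: CreditRating) extends Transition
  final case class DowngradedRating(newRating: CreditRating) extends Transition

  // Pure: the result depends only on the arguments.
  def calculateRating(stock: Symbol, old: CreditRating, newPrice: Price): CreditRating =
    if (stock.name == "AAPL") Good(1000) else if (newPrice.value == 0) Poor(0) else Fair(300)

  // Pure: compares the numeric ratings and records the rating that was set.
  def calculateTransition(oldRating: CreditRating, newRating: CreditRating): Transition =
    if (newRating.rating > oldRating.rating) UpgradedRating(newRating) else DowngradedRating(newRating)
}
```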

Stateful functions have the return type StateActionWithTransitions. This tells the reader to take special care, because the function reads or updates the state:

import com.softwaremill.quicklens._ 

// Replace the rating for this symbol in StockState.lastRatings
def setNewRating(symbol: Symbol, newRating: CreditRating): StateActionWithTransitions[Unit] = 
    lift(State.modify[StockState](_.modify(_.lastRatings).using(_ + (symbol -> newRating))))

// Read the last rating for this symbol, defaulting to Poor(0) if we have not seen it yet
def getRating(s: Symbol): StateActionWithTransitions[CreditRating] = 
    lift(State.inspect[StockState, CreditRating](_.lastRatings.getOrElse(s, Poor(0))))

Here is the final version of our function:

  • Whenever the reader sees <-, they know to pay special attention, as the right-hand side is a stateful function
  • Whenever the reader sees =, they know it's a pure function and nothing related to state happens there

def rateStatefulStock(symbol: Symbol, newPrice: Price): StateActionWithTransitions[CreditRating] =
    for {
        oldRating <- getRating(symbol)
        newRating = calculateRating(symbol, oldRating, newPrice)
        _ <- setNewRating(symbol, newRating)
        transition = calculateTransition(oldRating, newRating)
        _ <- archive(transition::Nil)
    } yield newRating


  • Before refactoring, make sure you have good tests with decent coverage
  • Strongly type as much as you can. Use meaningful names and abstractions
  • Design your functions so their types align and compose together
  • Use Cats' State data type to write functions that need state
  • Use type aliases to clean up boilerplate types
  • Use Cats' Writer data type (via WriterT) to log state transitions

Opinions expressed by DZone contributors are their own.