Re-Thinking the Visitor Pattern With Scala: Shapeless and Polymorphic Functions
In this article we will look at a relatively boilerplate-free way to traverse tree structures in Scala, using polymorphic functions.
Over the course of my career, a problem that I have had to face fairly repeatedly is dealing with a nested, tree-like structure of arbitrary depth. From XML to directory structures to building data models, nested trees or documents are a common and pretty useful way to model data.
Early in my career (the classic Java/J2EE/Spring days), I tackled them using the classic Visitor pattern from the Gang of Four, and I have probably had more than my fair share of implementing that pattern. Then, whilst working in Groovy, I re-imagined the pattern a little to make it more idiomatic (dealing mostly with Maps and Lists). Now I am working in Scala, and once again the problem has arisen.
There are lots of things that Scala handles well. I do generally like its type system, and everyone always raves about the pattern matching (which is undeniably useful), but it has always irked me a bit that, when dealing with subclasses, I have to match on every implementation to do anything. I always feel like it's something I should be able to do with type classes, and I inevitably end up a little sad every time I remember I can't.
Let me explain with a quick example: let's imagine we are modeling a structure like XML. I will assume we all know XML, but the format essentially allows you to define nested tree structures of elements. An element can be a complex type, like a directory or object, that holds further child elements, or a simple type, e.g. a string element that holds a string.
sealed trait Element
sealed trait SimpleElement[A] extends Element {
def value: A
}
case class ComplexElement (value: List[Element]) extends Element
case class TextElement (value: String) extends SimpleElement[String]
case class NumberElement (value: Double) extends SimpleElement[Double]
case class BooleanElement (value: Boolean) extends SimpleElement[Boolean]
Above is a basic setup to model a tree structure: a sealed trait for the generic element, a class for the complex element (an element that can hold a further list of child elements), and a few basic classes for the simple elements (String/Double/Boolean).
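To make the shape concrete, here is a minimal sketch (assuming Scala 2.13 or 3, standard library only) that builds a document nested two levels deep with the classes above. The depth helper is hypothetical and added only for illustration; note that even this tiny function needs a recursive match on the subtypes, which is exactly the kind of code that tends to get repeated.

```scala
// Element hierarchy from above, repeated so the sketch compiles on its own
sealed trait Element
sealed trait SimpleElement[A] extends Element {
  def value: A
}
case class ComplexElement(value: List[Element]) extends Element
case class TextElement(value: String) extends SimpleElement[String]
case class NumberElement(value: Double) extends SimpleElement[Double]
case class BooleanElement(value: Boolean) extends SimpleElement[Boolean]

// A record with a nested address block: two levels of ComplexElement
val person = ComplexElement(List(
  TextElement("Ada"),
  NumberElement(36.0),
  ComplexElement(List(
    TextElement("London"),
    BooleanElement(true)
  ))
))

// Hypothetical helper: nesting depth, computed by recursive pattern matching.
// Leaves count as depth 0; each ComplexElement adds one level.
def depth(e: Element): Int = e match {
  case ComplexElement(children) => 1 + children.map(depth).foldLeft(0)(_ max _)
  case _: SimpleElement[_]      => 0
}
```

Here depth(person) is 2: one level for the outer element and one for the address block.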
Now, when we have a ComplexElement and want to process its children (a List[Element]), ideally type classes would come to our rescue, with something like this:
sealed trait ValidatorTypeClass[A] {
def validate(a: A): Boolean
}
object ValidatorTypeClass {
def validateElement[A](a: A)(implicit v: ValidatorTypeClass[A]) = v.validate(a)
implicit def stringElementValidator = new ValidatorTypeClass[String] {
override def validate(a: String): Boolean = ??? //validation logic for strings
}
implicit def numberElementValidator = new ValidatorTypeClass[Double] {
override def validate(a: Double): Boolean = ??? //validation logic for numbers
}
implicit def booleanElementValidator = new ValidatorTypeClass[Boolean] {
override def validate(a: Boolean): Boolean = ??? //validation logic for booleans
}
implicit def complexElementValidator = new ValidatorTypeClass[ComplexElement] {
override def validate(a: ComplexElement): Boolean = a.value.forall(validateElement)
}
}
import ValidatorTypeClass._
val complex = ComplexElement(
value = List(
TextElement(value = "first element"),
TextElement(value = "second element")
)
)
validateElement(complex)
Above, we have a simple ValidatorTypeClass for which we define implementations for all the different types we care about. From there, it looks relatively simple to traverse a nested structure: the type class instance for ComplexElement simply iterates through the children and recursively passes each one to the type class for that child's type to handle the logic.
Note: I will use validation as an example throughout this article, but that is just for the sake of a simple illustration; there are many better ways to perform simple attribute validation in Scala, but this helps provide an example context for the problem.
However, if you try to compile the above code, you will get an error like this:
could not find implicit value for parameter v: ValidatorTypeClass[Element]
override def validate(a: ComplexElement): Boolean = a.value.forall(validateElement)
The reason is that the compiler is looking for an implicit type class instance for the parent type Element (the value attribute of ComplexElement is defined as List[Element]), which we haven't defined. Sure, we could define a ValidatorTypeClass[Element] instance and simply pattern match the input across all the implemented types, but at that point there's no point having type classes, and you just end up with a big old pattern-matching block. This works, but it feels kind of verbose, especially when the same blocks end up repeated throughout the code, as you inevitably have to handle the tree structure in several different places and ways.
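For comparison, a rough sketch of that pattern-matching fallback might look like this (standard library only; the specific rules, such as rejecting empty strings and NaN, are made up for illustration). Every new kind of traversal would need its own copy of the match inside elementValidator.

```scala
// Element hierarchy from above, repeated so the sketch compiles on its own
sealed trait Element
sealed trait SimpleElement[A] extends Element { def value: A }
case class ComplexElement(value: List[Element]) extends Element
case class TextElement(value: String) extends SimpleElement[String]
case class NumberElement(value: Double) extends SimpleElement[Double]
case class BooleanElement(value: Boolean) extends SimpleElement[Boolean]

trait ValidatorTypeClass[A] {
  def validate(a: A): Boolean
}

object ValidatorTypeClass {
  def validateElement[A](a: A)(implicit v: ValidatorTypeClass[A]): Boolean = v.validate(a)

  // Illustrative leaf rules
  implicit val stringValidator: ValidatorTypeClass[String] = new ValidatorTypeClass[String] {
    def validate(a: String): Boolean = a.nonEmpty
  }
  implicit val numberValidator: ValidatorTypeClass[Double] = new ValidatorTypeClass[Double] {
    def validate(a: Double): Boolean = !a.isNaN
  }
  implicit val booleanValidator: ValidatorTypeClass[Boolean] = new ValidatorTypeClass[Boolean] {
    def validate(a: Boolean): Boolean = true
  }

  // The "big old pattern-matching block": one instance for the parent type
  // that dispatches on every subtype by hand and recurses into children
  implicit val elementValidator: ValidatorTypeClass[Element] = new ValidatorTypeClass[Element] {
    def validate(a: Element): Boolean = a match {
      case ComplexElement(children) => children.forall(child => validate(child))
      case TextElement(s)           => validateElement(s)
      case NumberElement(d)         => validateElement(d)
      case BooleanElement(b)        => validateElement(b)
    }
  }
}

import ValidatorTypeClass._

// The root must be typed as Element so that elementValidator is selected
val doc: Element = ComplexElement(List(TextElement("first element"), ComplexElement(List(NumberElement(1.0)))))
val docIsValid: Boolean = validateElement(doc)
```

It compiles and works, but the dispatch logic now lives in one hand-written match that has to be kept in sync with the hierarchy.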
So, I wanted to find a better way, and having written about Shapeless a couple of times before, I thought I'd give it a try once again.
Enter Shapeless
The good news is that Shapeless has some tools that can help improve this. The bad news is that there isn't really any documentation on some of the features (beyond reading the source code and unit tests), and some of it just doesn't seem to be mentioned anywhere at all! I had previously used a function that Shapeless provides called everywhere. Even this function isn't really explicitly called out in the docs; I stumbled upon it in an article about what was new in Shapeless 2.0, where it was used in an example piece of code without any mention or explanation. everywhere allows in-place editing of tree-like structures (or any structures, really) and is based on the ideas laid out in the Scrap Your Boilerplate (SYB) paper, which large parts of the Shapeless library build on.
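To give a feel for what "in-place editing" means here, below is a hand-written, standard-library analogue of an everywhere-style rewrite for our Element type only. It is not Shapeless' implementation, which derives the traversal generically and also visits fields such as the String inside a TextElement; the name rewriteEverywhere is hypothetical. Like SYB's everywhere, it works bottom-up: children are rewritten first, then the function is applied to the rebuilt parent.

```scala
// Element hierarchy from above, repeated so the sketch compiles on its own
sealed trait Element
sealed trait SimpleElement[A] extends Element { def value: A }
case class ComplexElement(value: List[Element]) extends Element
case class TextElement(value: String) extends SimpleElement[String]
case class NumberElement(value: Double) extends SimpleElement[Double]
case class BooleanElement(value: Boolean) extends SimpleElement[Boolean]

// Hypothetical: apply `f` to every Element node, bottom-up, rebuilding the tree
def rewriteEverywhere(f: Element => Element)(e: Element): Element = e match {
  case ComplexElement(children) =>
    f(ComplexElement(children.map(child => rewriteEverywhere(f)(child))))
  case leaf =>
    f(leaf)
}

// Example edit: trim every piece of text, however deeply it is nested
val trimText: Element => Element = {
  case TextElement(s) => TextElement(s.trim)
  case other          => other
}

val messy = ComplexElement(List(
  TextElement("  first element "),
  ComplexElement(List(TextElement(" nested  "), NumberElement(1.0)))
))
val tidy: Element = rewriteEverywhere(trimText)(messy)
```

The edit function only has to describe the one case it cares about; the traversal is written once.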
As well as everywhere, Shapeless also provides a function called everything, which also comes from the SYB paper; instead of editing, it lets you simply traverse, or visit, generic data structures. It's pretty simple conceptually, but finding any mention of it in docs or footnotes was hard (I found it by reading the source code), so let's go through it.
everything takes three arguments, each in its own parameter list:
everything(validates)(combine)(complex)
validates is a polymorphic function applied at every step of the data structure, combine is a polymorphic function that combines the results, and complex (the third argument above) is our input: in this case, the root of our nested data model.
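Conceptually, everything applies the query to the current node, recurses into its children, and folds the children's results into the node's own result with the combining function. The sketch below spells that out by hand for the Element type only, using plain Scala functions in place of Shapeless' polymorphic functions; queryEverything is a hypothetical name, and the argument order simply mirrors the call above.

```scala
// Element hierarchy from above, repeated so the sketch compiles on its own
sealed trait Element
sealed trait SimpleElement[A] extends Element { def value: A }
case class ComplexElement(value: List[Element]) extends Element
case class TextElement(value: String) extends SimpleElement[String]
case class NumberElement(value: Double) extends SimpleElement[Double]
case class BooleanElement(value: Boolean) extends SimpleElement[Boolean]

// Hypothetical everything-style query: f on this node, then fold in the
// results from every child (recursively) using `combine`
def queryEverything[R](f: Element => R)(combine: (R, R) => R)(e: Element): R = {
  val here = f(e)
  e match {
    case ComplexElement(children) =>
      children.map(child => queryEverything(f)(combine)(child)).foldLeft(here)(combine)
    case _ => here
  }
}

val doc = ComplexElement(List(
  TextElement("first element"),
  ComplexElement(List(NumberElement(1.0), BooleanElement(true)))
))

// Query 1: count every node in the tree
val nodeCount: Int = queryEverything[Int](_ => 1)(_ + _)(doc)

// Query 2: validate, with a default case for nodes we don't care about,
// combining the per-node booleans with &&
val textIsNonEmpty: Element => Boolean = {
  case TextElement(s) => s.nonEmpty
  case _              => true
}
val allTextNonEmpty: Boolean = queryEverything(textIsNonEmpty)(_ && _)(doc)
```

The tree above has five nodes (two complex, three simple), so nodeCount is 5. Shapeless removes the need to write queryEverything itself, which is the point of the rest of this article.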
So, let's start with our polymorphic function for validating every step (this will be every attribute on each class, including lists, maps, and other classes that will then get traversed as well – you can find out more about polymorphic functions and how they are implemented with Shapeless here):
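sealed trait DefaultValidation extends Poly1 {
// Low-priority fallback for any type without a validator: a no-op that passes
implicit def default[T] = at[T](x => true)
}
object validates extends DefaultValidation {
// Preferred case for any type with a ValidatorTypeClass instance in scope
implicit def caseValidated[A](implicit v: ValidatorTypeClass[A]) = at[A](x => v.validate(x))
}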
So, what is happening here? And why do we have two polymorphic functions? Well, let's start with the second one, validates, which handles the validation. Remember our lovely and simple type class from earlier? We use it here: in this polymorphic function, we define an implicit case that matches any attribute it finds that has an implicit ValidatorTypeClass in scope, and runs the validation (in our simple example, returning a boolean for whether it passes or fails).
Now, there are also going to be other types in our structure that we essentially want to ignore. They might be simple attributes (Strings, etc.) or Lists, which we want to keep traversing but can pass over as types in themselves. For this, we need a polymorphic case that is essentially a no-op and returns true. As the cases in the polymorphic function are implicits, we need to put the default case in the parent trait so that it is resolved at a lower priority than our validating implicit.
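The low-priority trick is plain Scala implicit resolution rather than anything Shapeless-specific: when two implicits are equally applicable, the one defined in a subclass (or object) is preferred over the one inherited from a parent trait. Here is a minimal standard-library sketch of the same layering, using a hypothetical Check type class in place of the polymorphic function's cases.

```scala
// Hypothetical type class, used only to demonstrate implicit prioritisation
trait Check[A] {
  def apply(a: A): Boolean
}

// Fallback instances live in a parent trait, so they have lower priority...
trait LowPriorityChecks {
  implicit def default[A]: Check[A] = new Check[A] {
    def apply(a: A): Boolean = true // no-op: anything we don't care about passes
  }
}

// ...than the specific instances defined in the companion that extends it
object Check extends LowPriorityChecks {
  implicit val nonEmptyString: Check[String] = new Check[String] {
    def apply(a: String): Boolean = a.nonEmpty
  }
}

def check[A](a: A)(implicit c: Check[A]): Boolean = c(a)
```

With this layering, check("") uses the specific String instance and returns false, while check(42) finds only the fallback and returns true. Moving default into the object itself would typically make the String case ambiguous or change which instance wins, which is why validates extends DefaultValidation rather than declaring both cases side by side.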
So, everything is going to handle the generic traversal of our data structure, whatever that might look like, and this polymorphic function is going to return a boolean for each element indicating whether it is okay. Now, as mentioned, we need to combine all these results from across our structure.
To do that, we just define another polymorphic function, this time with arity 2, that defines how results are combined, which in the case of booleans is really very simple:
object combine extends Poly2 {
implicit def caseValidation = at[Boolean, Boolean] (_ && _)
}
This combinator will simply combine the booleans, and as soon as one element fails, the overall answer will be false.
Finally, if we put it all together we get some code that looks a bit like this:
val complex = ComplexElement(
value = List(
TextElement(value = "first element"),
TextElement(value = "second element")
)
)
import shapeless._
sealed trait DefaultValidation extends Poly1 {
implicit def default[T] = at[T](x => true)
}
object validates extends DefaultValidation {
implicit def caseValidated[A](implicit v: ValidatorTypeClass[A]) = at[A](x => v.validate(x))
}
object combine extends Poly2 {
implicit def caseValidation = at[Boolean, Boolean] (_ && _)
}
sealed trait ValidatorTypeClass[A] {
def validate(a: A): Boolean
}
object ValidatorTypeClass {
def validateElement[A](a: A)(implicit v: ValidatorTypeClass[A]) = v.validate(a)
implicit def stringElementValidator = new ValidatorTypeClass[String] {
override def validate(a: String): Boolean = true //validation logic for strings
}
implicit def numberElementValidator = new ValidatorTypeClass[Double] {
override def validate(a: Double): Boolean = true //validation logic for numbers
}
implicit def booleanElementValidator = new ValidatorTypeClass[Boolean] {
override def validate(a: Boolean): Boolean = true //validation logic for booleans
}
implicit def complexElementValidator = new ValidatorTypeClass[ComplexElement] {
override def validate(a: ComplexElement): Boolean = true
}
}
everything(validates)(combine)(complex)
In the above example, we have a standard type class solution for validating different types, and then just nine lines of code (the DefaultValidation, validates and combine definitions) to implement a Visitor-pattern-style solution that can traverse an arbitrary-depth tree of these types. And that's it! Shapeless' everything handles the boilerplate, and, with the addition of those minimal polymorphic functions, we don't need to worry about traversing anything or pattern matching on parent types. So, it ends up really quite nice. Nine extra lines of code, and our type class approach works after all!
Footnote 1: Further Removing Boilerplate
If you found yourself writing code like this a lot, you could simplify it further by broadening our implicit ValidatorTypeClass into a more general VisitorTypeClass and providing a common set of combinators for the combine polymorphic function. Then, all you would need to do each time is provide the specific type class implementations of VisitorTypeClass, and it would just work as if by magic.
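One possible shape for that generalization is sketched below, using only the standard library and a hand-written traversal instead of Shapeless' everything. VisitorTypeClass, Combinator and visitAll are hypothetical names, and this is just one reading of the suggestion: the traversal and the combinators are written once, and each new use only supplies visitor instances for the leaf types.

```scala
// Element hierarchy from above, repeated so the sketch compiles on its own
sealed trait Element
sealed trait SimpleElement[A] extends Element { def value: A }
case class ComplexElement(value: List[Element]) extends Element
case class TextElement(value: String) extends SimpleElement[String]
case class NumberElement(value: Double) extends SimpleElement[Double]
case class BooleanElement(value: Boolean) extends SimpleElement[Boolean]

// Visit a value of type A, producing a result of type R
trait VisitorTypeClass[A, R] {
  def visit(a: A): R
}

// Shared combinators: how to merge results of type R, plus a neutral value
trait Combinator[R] {
  def empty: R
  def combine(x: R, y: R): R
}

object Combinator {
  implicit val allTrue: Combinator[Boolean] = new Combinator[Boolean] {
    def empty: Boolean = true
    def combine(x: Boolean, y: Boolean): Boolean = x && y
  }
  implicit val sum: Combinator[Int] = new Combinator[Int] {
    def empty: Int = 0
    def combine(x: Int, y: Int): Int = x + y
  }
}

// Written once: the only pattern match over Element
def visitAll[R](e: Element)(implicit
    c: Combinator[R],
    vs: VisitorTypeClass[String, R],
    vd: VisitorTypeClass[Double, R],
    vb: VisitorTypeClass[Boolean, R]): R = e match {
  case ComplexElement(children) => children.map(child => visitAll[R](child)).foldLeft(c.empty)(c.combine)
  case TextElement(s)           => vs.visit(s)
  case NumberElement(d)         => vd.visit(d)
  case BooleanElement(b)        => vb.visit(b)
}

// Per-use instances: only these change from one traversal to the next
object validation {
  implicit val text: VisitorTypeClass[String, Boolean] = new VisitorTypeClass[String, Boolean] {
    def visit(a: String): Boolean = a.nonEmpty
  }
  implicit val number: VisitorTypeClass[Double, Boolean] = new VisitorTypeClass[Double, Boolean] {
    def visit(a: Double): Boolean = !a.isNaN
  }
  implicit val flag: VisitorTypeClass[Boolean, Boolean] = new VisitorTypeClass[Boolean, Boolean] {
    def visit(a: Boolean): Boolean = true
  }
}

val sample = ComplexElement(List(TextElement("a"), ComplexElement(List(NumberElement(1.0), BooleanElement(false)))))
val sampleIsValid: Boolean = {
  import validation._
  visitAll[Boolean](sample)
}
val emptyTextIsValid: Boolean = {
  import validation._
  visitAll[Boolean](ComplexElement(List(TextElement(""))))
}
```

A counting traversal would reuse the same visitAll and the existing Combinator[Int], needing only three new VisitorTypeClass[_, Int] instances.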
Footnote 2: A Better Validation
As mentioned, the validation example was purely illustrative, as it's a simple domain to understand, and there are other, better ways to perform simple validation (at construction time, with other libraries, etc.). But if we wanted this to perform real validation, rather than return booleans, we could use something like Validated from Cats, which would let us accumulate meaningful failures throughout the traversal. This is really simple to drop in: all we would need to do is implement the combine polymorphic function for the ValidatedNel type:
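import cats.data.ValidatedNel
import shapeless._
object combineFieldValues extends Poly2 {
// Validated's combine accumulates failures (needs Semigroup instances for both sides)
implicit def caseValidation = at[ValidatedNel[String, Boolean], ValidatedNel[String, Boolean]]((a, b) => a.combine(b))
}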
Thankfully, Cats' Validated forms a Semigroup (given Semigroups for both of its sides) and already provides the combine method, which accumulates failures, so all we need to do is call it!
Note: you will need to provide a Semigroup instance for whatever right-hand (valid) type you choose to use for Validated (Cats doesn't provide one for Boolean, for example), but that's trivial for most types.
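To illustrate the difference in behaviour without pulling in Cats, here is a standard-library stand-in: Checked and combineChecked are hypothetical names, and Either is used only as an approximation of ValidatedNel's accumulating combine (with Cats, the valid sides would be merged by a Semigroup[Boolean]; && plays that role here). Unlike the boolean version, which can only tell us that something failed, the combined result keeps every failure message.

```scala
// Stand-in for ValidatedNel[String, Boolean]: Left carries every failure seen so far
type Checked = Either[List[String], Boolean]

// Accumulating combine: unlike Either's fail-fast flatMap, failures from both sides are kept.
// Valid sides are merged with && (the role a user-supplied Semigroup[Boolean] plays with Cats).
def combineChecked(a: Checked, b: Checked): Checked = (a, b) match {
  case (Left(e1), Left(e2))  => Left(e1 ++ e2)
  case (Left(e1), Right(_))  => Left(e1)
  case (Right(_), Left(e2))  => Left(e2)
  case (Right(x), Right(y))  => Right(x && y)
}

// Per-field results as a traversal might produce them (values made up for illustration)
val results: List[Checked] = List(
  Right(true),
  Left(List("name is empty")),
  Right(true),
  Left(List("age is NaN"))
)

// Right(true) is neutral for combineChecked, so it is a safe starting value
val overall: Checked = results.foldLeft[Checked](Right(true))(combineChecked)
```

Here overall ends up as Left(List("name is empty", "age is NaN")): both failures survive, in traversal order.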
Hopefully, this has been interesting and even useful and comes in handy if you need to traverse tree-structured data (XML, JSON, etc.).
Published at DZone with permission of Rob Hinds, DZone MVB. See the original article here.