During our training and mentoring, we often hear students say “I’m not sure what to do here” when trying to choose which function to call to solve their problem. First, that’s completely normal; nobody instantly knows the answer. One piece of advice we’ve found useful takes the form of a heuristic:
The solution is usually either a map, flatMap, or fold.*
In this post we’ll be discussing the power of folds: how they are used to summarize a data structure, how summarization can be decomposed and abstracted, and how that abstraction can be leveraged to provide better performance in a generic way.
*And if it isn’t one of those, it’s probably traverse.
Decomposing Summarization
If a List contains numbers, we can compute the sum of the elements using a fold:
List(1, 2, 3).foldLeft(0)((sum, i) => sum + i)
// res0: Int = 6
foldLeft has a rather generic signature, so let’s factor out the necessary knowledge to summarize our List of Ints:
- We need a starting value to initialize our summary, in case the list is empty: ifEmpty: Int; and
- We need a way to combine the previous summary with the next element: combine: (Int, Int) => Int.
Let’s write a helper function summarize to name these parts, and just delegate to foldLeft as the implementation:
def summarize(is: List[Int], ifEmpty: Int, combine: (Int, Int) => Int): Int =
  is.foldLeft(ifEmpty)((sum, i) => combine(sum, i))
summarize(List(1, 2, 3), 0, _ + _)
// res1: Int = 6
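Because ifEmpty and combine are ordinary parameters, the same helper can compute other summaries just by swapping them. A minimal sketch (the product and maximum summaries are our own illustration, not part of the original example):

```scala
// Same helper as above, repeated so the snippet stands alone.
def summarize(is: List[Int], ifEmpty: Int, combine: (Int, Int) => Int): Int =
  is.foldLeft(ifEmpty)((sum, i) => combine(sum, i))

// Product: the starting value must be 1, not 0, or every result would be 0.
val product = summarize(List(1, 2, 3, 4), 1, _ * _)

// Maximum: Int.MinValue is what an empty list would produce.
val maximum = summarize(List(3, 7, 5), Int.MinValue, (a, b) => a max b)

// An empty list simply yields the starting value.
val emptySum = summarize(List.empty[Int], 0, _ + _)
```

Note how the starting value and the combining function travel together: 0 pairs with +, 1 pairs with *.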
But how could we summarize any list, not just a list of numbers? Let’s allow the caller to choose the element type, and adjust the ifEmpty and combine helpers:
def summarize2[A](as: List[A], ifEmpty: A, combine: (A, A) => A): A =
  as.foldLeft(ifEmpty)((sum, a) => combine(sum, a))
summarize2[Int](List(1, 2, 3), 0, _ + _)
// res2: Int = 6
Now the caller can summarize a list any way they want. But it’s annoying to have to supply ifEmpty and combine for every call; for a given type A, those parameters will (almost) always be the same.
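To make the repetition concrete, here is a sketch that summarizes a couple of other element types with summarize2 (the values are made up for illustration). Each call has to restate the starting value and combining function that naturally belong to the type:

```scala
// Same helper as above, repeated so the snippet stands alone.
def summarize2[A](as: List[A], ifEmpty: A, combine: (A, A) => A): A =
  as.foldLeft(ifEmpty)((sum, a) => combine(sum, a))

// Strings: start from "" and concatenate.
val word = summarize2[String](List("fold", "Map"), "", _ + _)

// Lists: start from Nil and append.
val flattened = summarize2[List[Int]](List(List(1), List(2, 3)), Nil, _ ++ _)

// Any other call on strings has to pass "" and _ + _ all over again.
val noWords = summarize2[String](Nil, "", _ + _)
```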
Capturing Common Behavior in a Typeclass
Let’s factor those two helpers into a typeclass:
// n.b. essentially the same as cats.Monoid
trait Monoid[A] {
  def empty(): A
  def combine(a1: A, a2: A): A
}
And then implicitly provide the Monoid containing our empty and combine helpers:
def summarize3[A](as: List[A])(implicit M: Monoid[A]): A =
  as.foldLeft(M.empty())((sum, a) => M.combine(sum, a))
Let’s create a Monoid[Int] instance so we can use it in summarize3:
implicit val intMonoid: Monoid[Int] =
  new Monoid[Int] {
    def empty(): Int = 0
    def combine(i1: Int, i2: Int): Int = i1 + i2
  }
Finally we’re back where we started:
summarize3(List(1, 2, 3))
// res3: Int = 6
summarize3 will now work for lists of any type A, as long as A has a Monoid[A].
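As a sketch of that claim, here is a second instance, this time for String. It is our own illustration, and the name stringMonoid is hypothetical. summarize3 needs no changes to work with it:

```scala
// Definitions from this post, repeated so the snippet stands alone.
trait Monoid[A] {
  def empty(): A
  def combine(a1: A, a2: A): A
}

def summarize3[A](as: List[A])(implicit M: Monoid[A]): A =
  as.foldLeft(M.empty())((sum, a) => M.combine(sum, a))

// Hypothetical instance: "" is the starting value, concatenation combines.
implicit val stringMonoid: Monoid[String] =
  new Monoid[String] {
    def empty(): String = ""
    def combine(s1: String, s2: String): String = s1 + s2
  }

val joined = summarize3(List("a", "b", "c"))
val nothing = summarize3(List.empty[String])
```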
Deriving foldMap
But what if we have a list of some type that doesn’t, or can’t, have a Monoid? To compute a summary for it, we’ll need to transform the list’s elements into a type that does have a Monoid. Is there a way to do this?
There is, and it’s a common strategy in functional programming: we ask for help from the caller, since our summarizing function itself can’t know what to do. If the caller provides us with a function A => B, and B has a Monoid, then we can summarize. Aha!
This function is usually called foldMap, because we’re both fold-ing (the list) and map-ping the elements (the A => B function):
// Summarize a List[A] as a B, if B has a Monoid.
def foldMap[A, B](as: List[A])(f: A => B)(implicit M: Monoid[B]): B =
  as.foldLeft(M.empty())((b, a) => M.combine(b, f(a)))
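For example, a list of records with no Monoid of their own can still be summarized by mapping each record to a number. This is a sketch using the definitions from this post, repeated so it stands alone; the Order type and its values are hypothetical:

```scala
trait Monoid[A] {
  def empty(): A
  def combine(a1: A, a2: A): A
}

implicit val intMonoid: Monoid[Int] =
  new Monoid[Int] {
    def empty(): Int = 0
    def combine(i1: Int, i2: Int): Int = i1 + i2
  }

def foldMap[A, B](as: List[A])(f: A => B)(implicit M: Monoid[B]): B =
  as.foldLeft(M.empty())((b, a) => M.combine(b, f(a)))

// Hypothetical element type: there is no Monoid[Order] anywhere.
final case class Order(item: String, quantity: Int)

val orders = List(Order("apple", 2), Order("pear", 1), Order("plum", 5))

// Order => Int takes us to a type that does have a Monoid.
val totalQuantity = foldMap(orders)(_.quantity)
```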
If A already has a Monoid, then we can summarize using foldMap by not transforming the data. That is, we transform it with the identity function:
def sum[A: Monoid](as: List[A]): A =
  foldMap(as)(identity)
sum(List(1, 2, 3, 4, 5))
// res4: Int = 15
It’s not very exciting, but we can count the length of a list too (implicitly using the Monoid[Int] to “increment”):
def count[A](as: List[A]): Int =
  foldMap(as)(_ => 1)
  //               ^
  //               increment count by 1
  //               for each element
count(List(1, 2, 3, 4, 5))
// res5: Int = 5
Computing the Mean with foldMap
What can you do with the sum and the count of a list? Compute the mean!
val l = List(1, 2, 3, 4, 5)
sum(l)
// res6: Int = 15
count(l)
// res7: Int = 5
def mean(is: List[Int]): Double =
  sum(is).toDouble / count(is)
mean(l)
// res8: Double = 3.0
It works!
Computing the Mean with foldMap: Improved!
For the performance-minded, there’s an issue with the computation of the mean: it processes the list twice, once for the sum and once for the count. Could we compute the mean in only one pass?
One way to approach this is to fill in, as best we can, the parameters of a single call to foldMap so that it computes the answer we’re looking for. That is, we need foldMap to return the sum AND the count:
def onePassMean(is: List[Int]): Double = {
  // need sum AND count from foldMap
  val (sum, count) =
    foldMap(is)(i => ???)
  sum.toDouble / count
}
If foldMap is to return an (Int, Int) tuple, the A => B mapping function (the ??? above) needs to return (Int, Int), and foldMap also requires a Monoid[(Int, Int)]:
def onePassMean(is: List[Int]): Double = {
  val (sum: Int, count: Int) =
    foldMap(is)(i => (???): (Int, Int))
    //                 ^
    //                 needs to produce (Int, Int)
  sum.toDouble / count
}
// error: could not find implicit value for parameter M: repl.Session.App.Monoid[(Int, Int)]
//     foldMap(is)(i => (???): (Int, Int))
//     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We know what each component of the tuple must hold: the first needs to be the sum and the second needs to be the count. When foldMap uses the Monoid[(Int, Int)] to combine tuples, it needs to work like this:
val sumAndCount = (10, 4)
// sumAndCount: (Int, Int) = (10, 4)
//                            ^   ^
//                            ^   previous count
//                            previous sum
val toAdd = (5, 1)
// toAdd: (Int, Int) = (5, 1)
//                      ^  ^
//                      ^  increment the count by 1
//                      increment the sum by 5
val totals = (10 + 5, 4 + 1)
// totals: (Int, Int) = (15, 5)
This generalizes to any tuple: combining two tuples produces one tuple with combined components. Sound familiar? Here’s the Monoid we need (provided a Monoid exists for each component of the tuple):
implicit def tuple2Monoid[A, B](implicit MA: Monoid[A], MB: Monoid[B]): Monoid[(A, B)] =
  new Monoid[(A, B)] {
    // the empty tuple is a tuple of empty values from each Monoid
    def empty(): (A, B) = (MA.empty(), MB.empty())
    // combine A fields via MA, combine B fields via MB
    def combine(ab1: (A, B), ab2: (A, B)): (A, B) =
      (ab1, ab2) match {
        case ((a1, b1), (a2, b2)) =>
          (MA.combine(a1, a2), MB.combine(b1, b2))
      }
  }
tuple2Monoid.combine((10, 4), (5, 1))
// res11: (Int, Int) = (15, 5)
Now we can compute our mean with one pass over the data:
def onePassMean(is: List[Int]): Double = {
  val (sum, count) =
    foldMap(is)(i => (i, 1))
    //                ^  ^
    //                ^  increment count by 1
    //                increment sum by i
  sum.toDouble / count
}
onePassMean(l)
// res12: Double = 3.0
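The same trick extends to more than two summaries, because tuples can nest: ((sum, count), sumOfSquares) has a Monoid whenever its parts do. As a sketch (our own extension, not part of the original example; onePassVariance is a hypothetical helper), here is a one-pass population variance, computed as the mean of the squares minus the square of the mean:

```scala
// Definitions from this post, repeated so the snippet stands alone.
trait Monoid[A] {
  def empty(): A
  def combine(a1: A, a2: A): A
}

implicit val intMonoid: Monoid[Int] =
  new Monoid[Int] {
    def empty(): Int = 0
    def combine(i1: Int, i2: Int): Int = i1 + i2
  }

implicit def tuple2Monoid[A, B](implicit MA: Monoid[A], MB: Monoid[B]): Monoid[(A, B)] =
  new Monoid[(A, B)] {
    def empty(): (A, B) = (MA.empty(), MB.empty())
    def combine(ab1: (A, B), ab2: (A, B)): (A, B) =
      (MA.combine(ab1._1, ab2._1), MB.combine(ab1._2, ab2._2))
  }

def foldMap[A, B](as: List[A])(f: A => B)(implicit M: Monoid[B]): B =
  as.foldLeft(M.empty())((b, a) => M.combine(b, f(a)))

// Hypothetical helper: three running totals gathered in a single pass.
def onePassVariance(is: List[Int]): Double = {
  val ((sum, count), sumOfSquares) =
    foldMap(is)(i => ((i, 1), i * i))
  val mean = sum.toDouble / count
  sumOfSquares.toDouble / count - mean * mean
}

val variance = onePassVariance(List(1, 2, 3, 4, 5))
```

The compiler assembles the Monoid[((Int, Int), Int)] from tuple2Monoid and intMonoid on its own. Like onePassMean, this sketch does not guard against an empty list, and the Int totals can overflow for large inputs.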
Summary
- Summarizing a List is a fold.
- The initial value and combining operation of elements for a summary can be abstracted over by a Monoid.
- foldMap summarizes a List by requiring each element be able to be transformed into something that has a Monoid.
- foldMap is available in cats as part of the Foldable typeclass, which models foldable (“summarizable”) types like List.
- We can derive a typeclass instance for tuples of monoids, which then gives us a way to combine values “in parallel”. Typeclass instances for tuples are included in cats, and are usually imported into scope via the wildcard import cats.implicits._.
// The same example as above, but using cats.
import cats.implicits._
List(1, 2, 3, 4, 5).foldMap(i => (i, 1))
// res13: (Int, Int) = (15, 5)
Further reading:
- A tutorial on the universality and expressiveness of fold by Graham Hutton.
- Beautiful folds by Gabriel Gonzalez.
- Algebird, a library from Twitter for building aggregations, used in systems like Summingbird.