Passing DBs Through Continuations
Dedicated to the Minnowbrook Analytic Reasoning Seminar, with special thanks to Kris Micinski and Michael Ballantyne.
Suppose you want to write a database. You'd probably start by implementing relational algebra operators: projection, filter, join, etc. The easy way is to implement them as functions that take in tables and return tables, and assemble them into a larger expression. That was how Prela worked in its first incarnation. The code was clean, but it was hella slow! Which was not surprising, because every operator materialized every intermediate result.
The standard solution to this is the iterator model, where each operator implements an Iterator interface that streams intermediate tables row by row instead of materializing them. But implementing the iterator model naively still incurs overhead: every call to Iterator.next() triggers a dynamic dispatch, which costs vtable lookups and destroys cache locality. There are two standard remedies: vectorization and compilation. A vectorized database amortizes the overhead by implementing Iterator.next_batch(), which returns a whole batch of data that can be processed together; a compiled database, well, compiles the incoming query directly to fast machine code that runs without any dynamic dispatch.
Either approach takes a lot of very smart people spending their entire working lives to build, and it's why systems like DuckDB and Umbra exist. I'm moderately smart but don't have a lot of time, so I was looking for a shortcut. The shortcut I stumbled upon was so beautiful that I literally cried[1] when I finally understood it, and I hope my explanation below will make you cry too :' )
To keep things simple, let's suppose we're just dealing with lists of numbers, and we want to do two very simple things to them: inc adds 1 to every number, and dbl doubles every number. That's pretty easy to write:[2]
inc(xs) = [x + 1 for x in xs]
dbl(xs) = [2 * x for x in xs]
Now we can chain them together with dbl(inc(xs)), which will do the two steps in sequence. The problem is that, because each function takes in a list and returns a list, our program produces an intermediate result, namely inc(xs). This allocates a new list only for it to be thrown away by the call to dbl. Things only get worse when we chain together multiple calls to inc and dbl. A more efficient implementation would fuse the operations together:
inc_n_dbl(xs) = [2 * (x + 1) for x in xs]
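As a quick sanity check, here is a minimal sketch that runs the chained and the hand-fused versions side by side. The input list and the variable names chained and fused are made up for illustration; the three function definitions are the ones above.

```julia
# List-in, list-out operators, as defined above.
inc(xs) = [x + 1 for x in xs]
dbl(xs) = [2 * x for x in xs]

# Hand-fused version: a single pass that builds a single output list.
inc_n_dbl(xs) = [2 * (x + 1) for x in xs]

xs = [1, 2, 3]          # illustrative input, not from the original post

chained = dbl(inc(xs))  # builds inc(xs) first, then a second list for the result
fused = inc_n_dbl(xs)   # builds only the result list

println(chained)        # [4, 6, 8]
println(fused)          # [4, 6, 8]
```

Both calls produce the same values; the difference is only in how many lists get allocated along the way.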
Of course, we can't write down every possible combination of operators like this. Is there a way to define each operator modularly, yet still have them compose into tightly fused operations automatically? Yes, if we use a bit of magic from functional compilers: continuation-passing style (CPS).
The key idea of CPS is to define operators that do things instead of making things. inc and dbl as defined above each take in a list and make a list. Instead, the CPS version of each operator takes in a list and an additional input k: this k is a function that the caller passes in, specifying what it wants to do with each element after the operator's work is done. k is called the continuation. Let's look at some code:
function inc(xs, k)
    for x in xs
        k(x + 1)
    end
end
Now suppose k is the print function. Then inc as defined above will add 1 to each number and print the result. Note that nothing is returned: inc only does its own job (adding 1) and then does what it's told to (apply k). As an exercise, you can try to write down dbl in CPS.
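The continuation doesn't have to print. Here is a small sketch, using the CPS inc from above, where the caller swaps in different continuations. The buffer out and the input list are hypothetical names chosen for this example.

```julia
# CPS inc, as defined above: apply k to each incremented element.
function inc(xs, k)
    for x in xs
        k(x + 1)
    end
end

# Continuation 1: print each result on its own line (2, 3, 4).
inc([1, 2, 3], println)

# Continuation 2: collect the results into a buffer instead.
out = Int[]                          # hypothetical buffer, not in the original
inc([1, 2, 3], x -> push!(out, x))
println(out)                         # [2, 3, 4]
```

inc itself never changes; only the caller decides whether a list gets built at all.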
But currently each of inc and dbl still takes in a list, and there's no obvious way to compose multiple operators. To do that, we replace xs with a "child" operator op:
inc(op, k) = op(x -> k(x + 1))
dbl(op, k) = op(x -> k(x * 2))
function scan(xs, k)
    for x in xs
        k(x)
    end
end
Intuitively, inc now trusts its child op to do its job, namely, that op will apply the continuation it receives to each item. So instead of iterating over xs, inc simply tacks the + 1 step onto the continuation and passes it to op. I've also defined a "source" operator scan that connects the input list to the operators. Let's see the code in action.
Start by calling inc(scan(xs), print).[3]
According to the definition of inc, this will call
scan(xs, x -> print(x + 1))
Plugging in the definition of scan, this gets us
for x in xs; print(x + 1); end
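To run these steps as written, scan(xs) with a single argument has to be an operator that is still waiting for its continuation. The sketch below assumes one way to get that: an extra one-argument method that returns a closure. That method is an assumption made for this example, not something the definitions above provide, and it collects into a hypothetical buffer out rather than printing.

```julia
inc(op, k) = op(x -> k(x + 1))

function scan(xs, k)
    for x in xs
        k(x)
    end
end

# Assumption: a one-argument method, so that scan(xs) is a function of k.
scan(xs) = k -> scan(xs, k)

out = Int[]                                # hypothetical buffer for the results
inc(scan([1, 2, 3]), x -> push!(out, x))   # same shape as inc(scan(xs), print)
println(out)                               # [2, 3, 4]
```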
So chaining together inc and scan indeed does what we want! Now let's try a longer chain dbl(inc(scan(xs)), print):
Expanding dbl gets us
inc(scan(xs), x -> print(x * 2))
Expanding inc gets us
scan(xs, x -> print((x + 1) * 2))
Finally, expanding scan gets us
for x in xs; print((x + 1) * 2); end
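The expansion can also be checked by running it. This sketch compares the operator chain against the hand-written fused loop. The one-argument methods are assumptions added so that the nested calls compose, and piped and fused are hypothetical buffers used in place of print.

```julia
function scan(xs, k)
    for x in xs
        k(x)
    end
end
inc(op, k) = op(x -> k(x + 1))
dbl(op, k) = op(x -> k(x * 2))

# Assumed one-argument methods: each returns an operator awaiting its continuation.
scan(xs) = k -> scan(xs, k)
inc(op) = k -> inc(op, k)

xs = [1, 2, 3]                 # illustrative input

# The operator chain, collecting results instead of printing them.
piped = Int[]
dbl(inc(scan(xs)), x -> push!(piped, x))

# The fully expanded loop from the last step above.
fused = Int[]
for x in xs
    push!(fused, (x + 1) * 2)
end

println(piped)                 # [4, 6, 8]
println(fused)                 # [4, 6, 8]
```

No intermediate list for inc is ever built in the chained version: each element flows through + 1 and * 2 before the next one is read.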
Notice how I used the word expand: if we annotate every operator definition with @inline, the compiler will actually unfold the code as we did above, and an operator chain gets compiled down to a fused loop in the end! You can try expanding longer chains like dbl(inc(dbl(inc(scan(xs)))), print) to get some practice thinking about CPS. Julia also has handy tools...