Making ast.walk 220x Faster


In our AI Reflex app builder, we generate massive amounts of Python code. Sometimes this code generation fails in rather trivial ways: positional parameters after keyword ones, returns with values in async generators, outdated syntax conventions from previous versions of our framework, etc.<br>Running reflex compile will eventually find all of those bugs, but it only reports one issue at a time. That means if the AI made multiple mistakes, each one costs another compile round trip, massively increasing latency for what could be relatively simple fixes.<br>So we decided a linter would be the best way to catch these. And since we need Reflex-specific rules, we couldn't use an existing one and had to build our own.<br>Our initial linter looked simple:
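A linter of this shape can be sketched as follows. The two rules and the `lint` name are illustrative assumptions, not the actual Reflex rules; the point is the structure of one `ast.walk` pass with `isinstance` checks that collects every issue instead of stopping at the first.

```python
import ast


def lint(source: str) -> list[str]:
    """Collect every issue in one pass instead of stopping at the first one."""
    issues = []
    for node in ast.walk(ast.parse(source)):
        # Illustrative rule 1: a bare `except:` swallows every exception.
        if isinstance(node, ast.ExceptHandler) and node.type is None:
            issues.append(f"line {node.lineno}: bare except")
        # Illustrative rule 2: mutable defaults are shared across calls.
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            for default in node.args.defaults:
                if isinstance(default, (ast.List, ast.Dict, ast.Set)):
                    issues.append(
                        f"line {default.lineno}: mutable default in {node.name}()"
                    )
    return issues
```

Every rule added this way pays for the same full traversal, which is why the cost of `ast.walk` dominates.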

Unfortunately, this approach ran into performance problems quickly, as we are processing a lot of generated code.<br>Notably, the slowest part of the above code is not the isinstance checks; it is ast.walk. Of course, there are many ways to speed up a linter that don't involve making ast.walk faster, and we implemented those lower-hanging-fruit changes first. However, we quickly realized it was hard to make the linter significantly faster without taking on the challenge of making walk itself faster.<br>Walking an abstract syntax tree (which, for the purposes of this code, is just a regular tree) doesn't have to be slow, though. Walking the difflib module took ~2ms on my device, for ~7,000 nodes. On its own, this isn't horrible, but it quickly adds up. A rough calculation (2 ms / 7,000 nodes) gives us ~285 nanoseconds per node, on the order of a thousand CPU cycles - far more than such a simple traversal should need. So what on earth is ast.walk doing?
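As a point of reference, CPython's `ast.walk` is approximately the following. It is reproduced here under the assumed name `walk_reference` so it can sit next to the real function; the exact source may differ slightly between Python versions.

```python
import ast
from collections import deque


def walk_reference(node):
    """Breadth-first traversal that yields every node, mirroring ast.walk."""
    todo = deque([node])
    while todo:
        node = todo.popleft()
        # iter_child_nodes is itself a generator, consumed here by extend().
        todo.extend(ast.iter_child_nodes(node))
        yield node
```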

The first thing to note here is the use of yield. Generators and yield are a powerful feature of Python, but they come at a sharp cost: the loop is suspended and resumed for every node, in a hot path where we consume the full sequence anyway. Sure, it saves memory, but that's not our problem at the moment, especially if the list gets cleaned up afterwards. If we simply build a list and keep appending to it, we can avoid most of that overhead:
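A minimal sketch of that change, assuming the rest of the traversal stays as in the standard library (the name `walk_list` is hypothetical):

```python
import ast
from collections import deque


def walk_list(root):
    """Same traversal as ast.walk, but returns a list instead of yielding."""
    result = []
    todo = deque([root])
    while todo:
        node = todo.popleft()
        todo.extend(ast.iter_child_nodes(node))
        result.append(node)  # append instead of yield: no suspend/resume
    return result
```

Callers can still write `for node in walk_list(tree)`, so the linter loop does not need to change.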

But after running it, I realized it only gets us a ~5% improvement. That's much less than I expected, and it makes me wonder: what is iter_child_nodes doing?
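For context, `ast.iter_child_nodes` in CPython is approximately the following, shown under the assumed name `iter_child_nodes_reference`; details may vary by version.

```python
import ast


def iter_child_nodes_reference(node):
    """Yield direct children: fields that are nodes or lists of nodes."""
    for _name, field in ast.iter_fields(node):
        if isinstance(field, ast.AST):
            yield field
        elif isinstance(field, list):
            for item in field:
                if isinstance(item, ast.AST):
                    yield item
```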

Ah, another generator. If we inline it, we should see some decent performance improvements:
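Inlining it could look like the sketch below. `walk_inlined` is a hypothetical name, and `ast.iter_fields` is deliberately still called at this stage.

```python
import ast
from collections import deque


def walk_inlined(root):
    """Traversal with the body of iter_child_nodes inlined into the loop."""
    result = []
    todo = deque([root])
    while todo:
        node = todo.popleft()
        result.append(node)
        for _name, value in ast.iter_fields(node):
            if isinstance(value, ast.AST):
                todo.append(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        todo.append(item)
    return result
```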

We do get the improvement we are looking for, around 25%. That raises the question: where is the remaining 75%? And while we're on the topic, what is iter_fields? Hello Python?
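`ast.iter_fields` in CPython is approximately the following, shown under the assumed name `iter_fields_reference`; details may vary by version.

```python
import ast


def iter_fields_reference(node):
    """Yield (name, value) for each field in node._fields that is present."""
    for field in node._fields:
        try:
            yield field, getattr(node, field)
        except AttributeError:
            # Missing fields are skipped with a Python-level exception handler.
            pass
```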

Here we go again, another generator. We are also yielding a tuple that we never fully use, as we only care about the value, not the name.<br>Using getattr(node, field, None) instead of a try/except around getattr should also be faster, since the missing-attribute case is then handled inside CPython rather than by a Python-level exception handler.<br>Let's combine both of those:
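Combining both ideas might look like this sketch (`walk_getattr` is a hypothetical name). It reads `node._fields` directly and relies on the three-argument form of `getattr`, so no `ast` helper is called any more.

```python
import ast
from collections import deque


def walk_getattr(root):
    """Traversal that no longer calls any helper from the ast module."""
    result = []
    todo = deque([root])
    while todo:
        node = todo.popleft()
        result.append(node)
        for name in node._fields:
            # A missing field becomes None, which fails both checks below.
            value = getattr(node, name, None)
            if isinstance(value, ast.AST):
                todo.append(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        todo.append(item)
    return result
```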

That gets us to around 50% cumulative improvement. This removes the last function we're calling from the ast module, so it's only us to blame here.<br>Twice as fast isn't bad, but is walking at 2x speed called sprinting? I don't think so. Let's push this further.<br>We can read _fields and check the subclassing in the same call. No other object with a ._fields attribute can appear in a well-formed AST, so we should be safe.
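One way to read that idea is sketched below, under the stated assumption that only AST nodes carry `_fields` in a well-formed tree. The name `walk_fields` and the explicit stack are choices made for this example.

```python
import ast


def walk_fields(root):
    """Use the presence of `_fields` as the 'is this an AST node?' test."""
    result = []
    stack = [(root, root._fields)]
    while stack:
        node, fields = stack.pop()
        result.append(node)
        for name in fields:
            value = getattr(node, name, None)
            # One getattr both detects a node and fetches its field names.
            # Compare against None: leaf nodes such as Load have _fields == ().
            child_fields = getattr(value, "_fields", None)
            if child_fields is not None:
                stack.append((value, child_fields))
            elif isinstance(value, list):
                for item in value:
                    item_fields = getattr(item, "_fields", None)
                    if item_fields is not None:
                        stack.append((item, item_fields))
    return result
```

Note that a namedtuple also has `_fields`, so this shortcut is only safe because such objects do not occur in trees produced by `ast.parse`. The visiting order also differs from `ast.walk`, which a linter that inspects nodes independently should not care about.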

That only adds another incremental improvement though, maybe 55% cumulative. At this point I'm reaching the limits of what I can do in Python. Making this iterative instead of recursive barely moves the needle.<br>But Python has one last trick up its sleeve: bindings. They allow us to write this logic in native machine code (or something that compiles to it). While I could use C here, I opted for Rust, simply because that's what I'm used to. Let's do a simple transliteration:

I'm using cast_unchecked because PyO3 would do heavy type checking if I used regular casting. Sometimes that type checking is useful, but not here. Other than that, the above code is not so different; getattr_opt is the equivalent of getattr(..., ..., None).<br>That gets us to around 78% cumulative improvement. Nice!<br>This also allows us to do something more interesting. See, since we are doing a lot of getattr calls, each of which boils down to a dictionary lookup, we can simply iterate over the dictionary itself. In Python, that dictionary is called __dict__. We can simply read it at a memory offset:
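The idea can be tried from Python before dropping to native code. This sketch uses `vars(node)` rather than a raw memory offset, so it illustrates the traversal strategy only, not the native implementation; `walk_dict` is a hypothetical name.

```python
import ast


def walk_dict(root):
    """Iterate each node's instance dictionary instead of calling getattr."""
    result = []
    stack = [root]
    while stack:
        node = stack.pop()
        result.append(node)
        # vars(node) is node.__dict__: the fields plus position attributes
        # such as lineno, which are ints and fail both checks below.
        for value in vars(node).values():
            if isinstance(value, ast.AST):
                stack.append(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        stack.append(item)
    return result
```

This assumes nothing else has been attached to the nodes; an extra attribute holding an AST node would be walked too.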

We can also improve our subclass checking. There are only 132 classes that subclass ast.AST, so instead of a real isinstance call we can store the memory addresses of all of those classes in a set and simply check membership. (I first reached for fastset here, but it's built for small, dense integers and panics on 64-bit pointer values, so a plain hash set it is.)
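A Python analogue of the address set, as a rough sketch: in CPython, `id()` of a class is its memory address, so a set of ids plays the role of the pointer set. The helper names are hypothetical, and the class count is not asserted because it varies between Python versions.

```python
import ast


def all_subclasses(cls):
    """Recursively collect cls and every class derived from it."""
    found = {cls}
    for sub in cls.__subclasses__():
        found |= all_subclasses(sub)
    return found


# Snapshot taken once; AST subclasses defined later are not included.
AST_TYPE_IDS = frozenset(id(c) for c in all_subclasses(ast.AST))


def is_ast_node(obj):
    """Exact-type membership test replacing isinstance(obj, ast.AST)."""
    return id(type(obj)) in AST_TYPE_IDS
```

The check is a single hash lookup on the object's exact type, with no walk up the class hierarchy.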

Then we combine both of those:

That gets us to ~93%. In total, that's ~14 times faster.<br>The only CPython call we have left is inside of BorrowedDictIter which calls...
