Testing the compiler optimizations your code relies on
In a recent article, David Lattimore demonstrates a number of Rust performance tricks, including one that involves writing code that looks like a loop, but which in practice is optimized down to a fixed number of instructions. Having what looks like an O(n) loop turned into a constant-time operation is great for speed!
But there’s a problem with this sort of trick: how do you know the compiler will keep doing it? What happens when the compiler’s next release comes out? How can you catch performance regressions?
One solution is benchmarking: you measure your code’s speed, and if it gets a lot slower, something has gone wrong. This is useful and important if you care about speed. But benchmarks are also less localized, so they won’t necessarily pinpoint where the regression happened.
In this article I’m going to cover another approach: a test that will only pass if the compiler really did optimize the loop away.
An example compiler optimization
My main example uses the Numba compiler, which builds on LLVM; I’ll mention some Rust-specific concerns later.
The following function looks like a loop, but should be optimized down to a constant number of instructions:
from numba import jit
# Functions decorated with @jit are compiled to machine code
# the first time they are called.
@jit
def range_sum(n):
    result = 0
    for i in range(1, n + 1):
        result += i
    return result
assert range_sum(4) == 10 # = 4 + 3 + 2 + 1
On the other hand, this function will stay an O(n) loop:
from math import log
@jit
def range_sum_of_logs(n):
    result = 0
    for i in range(1, n + 1):
        result += log(i)
    return result
assert range_sum_of_logs(3) == log(3) + log(2) + log(1)
A starting point: measuring run time
How can we prove that range_sum() is being optimized such that it takes a constant amount of time? By running it with different input sizes and comparing the run time. The problem with run time is that it’s noisy:
from time import time
results = []
for i in range(1_000):
    start = time()
    range_sum_of_logs(100_000)
    results.append(time() - start)

print(
    "Maximum elapsed time was "
    f"{max(results) / min(results):.4}× the minimum"
)
Maximum elapsed time was 1.594× the minimum
To get around this problem of noise, we can run the function many times and average the run time:
from time import time
def timeit(f, *args, **kwargs):
    start = time()
    for _ in range(1000):
        f(*args, **kwargs)
    return (time() - start) / 1000
Now that we have a slightly more reliable measure of speed, we can see how range_sum()’s speed changes for different input sizes. And as promised, range_sum() takes approximately the same amount of time regardless of the input size:
print("range_sum(10_000): ", timeit(range_sum, 10_000))
print("range_sum(100_000):", timeit(range_sum, 100_000))
range_sum(10_000): 1.3136863708496095e-07
range_sum(100_000): 1.4185905456542968e-07
In contrast, range_sum_of_logs() runs 10× slower for a 10× larger input. In other words, it is O(n):
print("range_sum_of_logs(10_000): ",
timeit(range_sum_of_logs, 10_000))
print("range_sum_of_logs(100_000):",
timeit(range_sum_of_logs, 100_000))
range_sum_of_logs(10_000): 0.00010006403923034668
range_sum_of_logs(100_000): 0.0009971213340759276
Using run time as a metric works, but it requires us to run the function many times to reduce noise. So we want a different measure that is more consistent across runs, and can still tell us if our code is constant time or not.
A better metric: CPU instructions
If the compiler is optimizing our code into constant time, that means that when the code is run, it will run a constant number of machine code instructions. So we can measure the number of CPU instructions used when we run the function, and if it’s sufficiently close across different input sizes, that means the compiler generated constant-time code.
It’s worth keeping in mind that the number of CPU instructions doesn’t always tell you how fast code is. Effects like branch misprediction and instruction-level parallelism (covered in my upcoming book) mean that a function that executes more CPU instructions might still be faster. But that caveat applies when comparing different implementations. In this case, we’re only looking at a single implementation, and we’re interested in whether it runs in constant time regardless of input size; counting the CPU instructions executed at run time is sufficient to measure that.
On Linux you can do this with the perf subsystem; here I’ll be accessing it using the py-perf-event package. A cross-platform alternative is the pypapi package.
from py_perf_event import measure, Hardware
def count_instructions(f, *args, **kwargs):
    # The while loop is needed to work around a rare issue where
    # the result is 0; typically this will only run once.
    while True:
        result = measure(
            [Hardware.INSTRUCTIONS], f, *args, **kwargs
        )
        if result[0] != 0:
            break
    return result[0]
This measure is far less noisy than the run time. We can see this if we run it individually:
print(count_instructions(range_sum_of_logs, 100_000))
print(count_instructions(range_sum_of_logs, 100_000))
print(count_instructions(range_sum_of_logs, 100_000))
6138911
6138911
6138911
Depending on the run, either there is no difference at all, or the differences are so tiny as not to matter. Which means we can run our function just once, unlike measuring run time, which requires us to run the function many times.
And now we can reliably and quickly check if a function is constant time regardless of input size:
def is_constant_time(function, small_arg, large_arg):
    """Return whether the function does a constant amount of work.

    Make sure `large_arg` is 10× as big as `small_arg`.
    """
    small_instrs = count_instructions(function, small_arg)
    large_instrs = count_instructions(function, large_arg)
    # We allow for noise of up to 50%, given the requirement
    # that the larger input be 900% larger than the smaller
    # one.
    return (large_instrs / small_instrs) < 1.5
print(
    "Is range_sum() constant time?",
    is_constant_time(range_sum, 10_000, 100_000),
)
print(
    "Is range_sum_of_logs constant time?",
    is_constant_time(range_sum_of_logs, 10_000, 100_000),
)
Is range_sum() constant time? True
Is range_sum_of_logs constant time? False
This check can be easily integrated into a test suite.
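For example, here’s a minimal sketch of what such a test might look like with pytest; mycode is a stand-in name for wherever range_sum(), range_sum_of_logs(), and is_constant_time() actually live:

# A hypothetical test module; adjust the import to match your code.
from mycode import is_constant_time, range_sum, range_sum_of_logs

def test_range_sum_is_constant_time():
    # Call the function once first so Numba's JIT compilation isn't
    # included in the instruction counts.
    range_sum(10)
    assert is_constant_time(range_sum, 10_000, 100_000)

def test_range_sum_of_logs_is_not_constant_time():
    range_sum_of_logs(10)
    assert not is_constant_time(range_sum_of_logs, 10_000, 100_000)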
Using this technique in Rust
In Rust, you can measure CPU instructions on Linux with the perf-event2 crate. But there are a few more things to worry about.
First, hardcoding arguments to a function gives the compiler more information, to the point where it might optimize the function in ways that won’t happen in real-world usage. This wasn’t a problem with the Numba examples above, since the inputs come from Python. The solution is to use std::hint::black_box() to hide information from the compiler; see the linked documentation for details.
The other problem is that Rust tests are compiled without optimizations by default, and even with optimizations enabled, the presence of debug assertions may still reduce how aggressively the compiler optimizes. This wasn’t a problem with the Numba examples above, since Numba compiles with optimizations enabled by default. To solve this, you can write the tests like this:
#[cfg(test)]
mod tests {
    #[cfg(not(debug_assertions))]
    #[test]
    fn myfunc_is_constant_time() {
        // ... Write the test here ...
    }
}
By default this test will be skipped. But if you run tests with cargo test --profile=release, the extra test will run, and it will run against optimized code.
Testing your performance guarantees
Similar techniques can be used in other contexts. This article was inspired by some of the testing approaches in the Axiom object database/ORM, which is built on top of SQLite. When you run a SQLite query, you can measure how many SQLite bytecode instructions were executed, which is far more consistent than measuring run time (you can use this API from Python, for example).
Axiom uses this functionality to test that queries perform as expected. For example, you can write a test that checks whether a query is efficient by running the same query on multiple table sizes. If the query is inefficient and results in a linear scan, the number of bytecode instructions executed will grow linearly with the size of the table. But if the query can use indexes, the number of bytecode instructions won’t grow linearly; it should stay small regardless of the size of the database table.
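I’m not going to reproduce Axiom’s actual test code here, but here’s a rough sketch of the idea using only the standard library: sqlite3.Connection.set_progress_handler() invokes a callback roughly every N SQLite virtual machine instructions, so with N=1 it can approximate how much bytecode a query executes at different table sizes. The table and queries below are made up for illustration:

import sqlite3

def count_vm_instructions(con, query):
    # Approximate the number of SQLite VM instructions executed by
    # counting progress-handler callbacks, fired (roughly) every
    # single VM instruction.
    calls = 0
    def tick():
        nonlocal calls
        calls += 1
    con.set_progress_handler(tick, 1)
    con.execute(query).fetchall()
    con.set_progress_handler(None, 1)
    return calls

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT)")
con.executemany(
    "INSERT INTO items (name) VALUES (?)",
    [(f"item{i}",) for i in range(10_000)],
)
# A primary-key lookup should execute a small, roughly fixed amount of
# bytecode; filtering on the unindexed `name` column forces a full
# scan, so its count grows with the number of rows.
print(count_vm_instructions(con, "SELECT * FROM items WHERE id = 500"))
print(count_vm_instructions(con, "SELECT * FROM items WHERE name = 'item500'"))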
More broadly, with a little creativity you can test many performance guarantees, without having to rely solely on benchmarks.