When NumPy is too slow
If you’re doing numeric calculations, NumPy is a lot faster than plain Python—but sometimes that’s not enough. What should you do when your NumPy-based code is too slow?
Your first thought might be parallelism, but that should probably be the last thing you consider. There are many speedups you can apply before parallelism becomes helpful, from algorithmic improvements to working around NumPy’s architectural limitations.
Let’s see why NumPy can be slow, and then look at some solutions that can speed up your code even more.
Step #1: Before you optimize, choose a scalable algorithm
Before you spend too much time thinking about speeding up your NumPy code, it’s worth making sure you’ve picked a scalable algorithm.
An O(N) algorithm will scale much better than an O(N²) algorithm; the latter quickly becomes unusable as N grows, even with a fast implementation.
For example, I have seen real-world cases where a binary search on a sorted Python list was a lot faster than a linear search in C.
Thus, once you’re dealing with large amounts of data—almost certainly the case if NumPy is slow—choosing a good algorithm is critical to speed. You don’t want to waste time optimizing an algorithm you will need to swap out later.
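To make the binary-search example concrete, here is a minimal sketch using Python’s standard-library bisect module; the function names are my own, chosen for illustration:

```python
import bisect

def linear_search(sorted_list, value):
    # O(N): checks every element until it finds a match.
    for i, item in enumerate(sorted_list):
        if item == value:
            return i
    return -1

def binary_search(sorted_list, value):
    # O(log N): bisect_left halves the search space on each step.
    i = bisect.bisect_left(sorted_list, value)
    if i < len(sorted_list) and sorted_list[i] == value:
        return i
    return -1
```

For a million-element list, the binary search does roughly 20 comparisons instead of up to a million, which is how an interpreted O(log N) algorithm can beat a compiled O(N) one.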
NumPy’s inherent performance limits
Once you have a good algorithm, you can start thinking about optimization. Focusing on single-core execution, with no parallelism, there are three NumPy-specific reasons why NumPy-based code might not be as fast as it could be.
Bottleneck #1: Eager execution
Consider the following code:
def add_multiple(arr):
    result = arr + 17
    result += 17
    result += 17
    result += 17
    result += 17
    result += 17
    result += 17
    result += 17
    result += 17
    return result
Now, clearly adding those numbers over and over again is silly. A more efficient implementation that runs 4× faster looks like this:
def add_multiple_efficient(arr):
    return arr + (17 * 9)
However, NumPy cannot automatically transform the function from the first form to the second form, because NumPy doesn’t know anything about the overall execution. NumPy will execute every statement you give it, one by one, with no knowledge of future statements.
Put another way, individual operations may be fast, but there is no mechanism to automatically optimize a series of operations… even when the optimization is very obvious.
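If you want to check the speed difference yourself, here is a minimal timing sketch; the array size and repeat count are arbitrary choices:

```python
import numpy as np
from timeit import timeit

arr = np.random.rand(10_000_000)

# Nine separate array operations, executed eagerly one at a time:
print(timeit(lambda: add_multiple(arr), number=10))
# A single array operation, with the constant folded by hand:
print(timeit(lambda: add_multiple_efficient(arr), number=10))
```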
Bottleneck #2: Generic compiled code
Different CPUs have different capabilities, for example different sets of specialized SIMD (single instruction multiple data) operations that can significantly speed up numeric code. However, the NumPy package—or more broadly any pre-compiled package—you download from PyPI or Conda has to run on any random computer you install it on.
NumPy does provide hand-written SIMD implementations for a growing number of operations, but this support has to be added per operation. The compiler can sometimes generate SIMD instructions automatically, but it will be limited to the lowest common denominator of CPUs; modern CPU instructions won’t be used.
More broadly, different CPUs have different costs for different operations. Since the compiled version of NumPy you download needs to run on different CPUs, it will be compiled with a broad cost model that is vaguely accurate, but it will never quite match your CPU’s specific characteristics.
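You can see what your installed copy of NumPy was built for with np.show_config(); on recent NumPy versions the output includes which SIMD extensions were used as the build baseline and which were detected at runtime (the exact format varies by version):

```python
import numpy as np

# Prints build information for the installed NumPy; recent versions
# include a section listing baseline and runtime-detected SIMD
# extensions (e.g. SSE2, AVX2, AVX512).
np.show_config()
```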
Bottleneck #3: Higher memory usage from vectorization
Consider the following function:
import numpy as np

def mean_distance_from_zero(arr):
    return np.abs(arr).mean()
This will create a temporary array of absolute values, an array that would be unnecessary if we could efficiently use a for loop with NumPy. Besides increasing memory usage, the temporary also affects the CPU: the caches are used less efficiently, since the extra data has to be loaded into them, evicting values that will be needed again later.
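To make the hidden allocation explicit, here’s the same function with the temporary spelled out, plus a loop version that uses constant extra memory (in pure Python the loop is far too slow to be a real alternative, which is exactly the limitation described above):

```python
import numpy as np

def mean_distance_from_zero(arr):
    temp = np.abs(arr)  # allocates a whole new array the same size as arr
    return temp.mean()

def mean_distance_from_zero_loop(arr):
    # Constant extra memory: no temporary array is allocated. But the
    # Python-level loop makes this orders of magnitude slower.
    total = 0.0
    for item in arr:
        total += abs(item)
    return total / arr.size
```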
Elsewhere you can learn more about the limits of NumPy’s vectorization.
Step #2: Optimizing your NumPy code
How do you address these three bottlenecks? You have multiple options:
Option #1: Manually optimizing NumPy code
There are a number of techniques you can use.
Rewriting your NumPy code to be more efficient
Even given a scalable algorithm, there may be ways to make your code more efficient by getting rid of redundant work or restructuring your code.
For example, notice that both versions of add_multiple() above are O(N): they scale linearly with the size of the input. Double the input array length, and you’ll double the runtime. But even though they scale similarly, one implementation is much faster than the other, because it does less work per entry.
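Another common restructuring is reusing a preallocated buffer via the out= parameter that NumPy ufuncs accept, so repeated calls don’t each allocate a fresh temporary. Here’s a minimal sketch; the buffer-passing convention is my own, not a NumPy API:

```python
import numpy as np

def mean_distance_from_zero_reuse(arr, buffer):
    # Write absolute values into the caller-provided buffer instead
    # of allocating a new temporary array on every call.
    np.abs(arr, out=buffer)
    return buffer.mean()

arr = np.random.rand(1_000_000)
buffer = np.empty_like(arr)
result = mean_distance_from_zero_reuse(arr, buffer)
```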
Using pre-existing native code functions
In many cases the operation you need is implemented in efficient native code in NumPy, or other libraries like SciPy, Scikit-Image, and so on. These implementations will often run faster than an implementation you would write yourself.
For example, we could implement a variance calculation ourselves:
def myvar(arr):
    mean = arr.mean()
    diff = arr - mean
    diff **= 2
    return diff.sum() / arr.size
Or we could just use numpy.var(). And as it turns out, numpy.var() is slightly faster, about 1.25× on my computer.
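As a quick sanity check that the two agree (up to floating-point rounding); you can wrap both calls in timeit to reproduce the speed comparison on your own machine:

```python
import numpy as np

arr = np.random.rand(1_000_000)
# Same result, up to floating-point rounding:
assert np.isclose(myvar(arr), np.var(arr))
```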
It’s worth getting to know the NumPy APIs, the SciPy APIs, and other relevant libraries, so that you know what operations are already available to you.
Re-compile NumPy and friends targeting your specific CPU
It’s possible to recompile NumPy and other libraries so they’re built specifically for your CPU. In some cases this will speed up operations on your computer, because the compiler can use all the instructions your particular CPU supports. In practice, this is likely painful enough that it’s not worth doing for most people, most of the time.
Option #2: Automated speedups with JAX just-in-time compilation
Instead of rewriting the code manually, you can use the JAX library’s just-in-time compilation to automatically optimize functions that use NumPy.
To simplify, jax.jit() will take a function that uses NumPy APIs and compile it to native code on the fly. It also has a lower-level API that gives you more control.
- Unlike NumPy’s one-operation-at-a-time eager execution, JAX’s JIT analyzes a whole function, so it can optimize code across operations.
- It generates code just-in-time so in theory it can be customized to your local CPU.
- It may decide to use multiple CPUs.
For example:
from jax import jit

@jit
def add_multiple(arr):
    result = arr + 17
    result += 17
    result += 17
    result += 17
    result += 17
    result += 17
    result += 17
    result += 17
    result += 17
    return result
This function runs 4× as fast as the original NumPy version, and because JAX also uses multiple CPUs, it returns results 12× as fast.
This example is pretty silly, of course, because there’s no reason to write code that looks like the above. But JAX can also optimize NumPy operations where it’s not as clear how to manually make the code more efficient.
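JAX’s whole-function view can also help with the memory bottleneck from earlier. Here’s a hedged sketch of the mean_distance_from_zero() example under jax.jit; whether the compiler actually fuses the two operations into one pass is up to XLA, not guaranteed:

```python
import jax
import jax.numpy as jnp

@jax.jit
def mean_distance_from_zero(arr):
    # Because JAX traces the whole function before compiling, XLA has
    # the opportunity to fuse abs() and mean() into a single pass,
    # instead of materializing a temporary array of absolute values.
    return jnp.abs(arr).mean()
```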
Option #3: Rewrite the code with Numba just-in-time compilation
Some operations are hard to express as full-array operations, or are at least memory-inefficient that way, and you really want a for loop to implement them.
JAX has some support for this with its lower-level API, but another just-in-time compilation solution is Numba, which gives you more control while still providing the benefits of a compiler that targets your local CPU.
In particular, you can directly use a subset of Python to write code with for loops, which then gets compiled to native code.
For example, this variance implementation is 3× faster than NumPy’s on my computer:
from numba import njit
import numpy as np

@njit
def myvar2(arr):
    # An explicit loop: no temporary arrays are allocated.
    mean = arr.mean()
    total = np.float64(0.0)
    for item in arr:
        total += (np.float64(item) - mean) ** 2
    return total / arr.size
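One practical note: Numba compiles the function the first time it’s called, so benchmarks should warm it up first:

```python
import numpy as np

arr = np.random.rand(1_000_000)
myvar2(arr)           # first call triggers compilation; exclude from timings
result = myvar2(arr)  # later calls run the already-compiled native code
```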
Option #4: Ahead-of-time compilation
Numba lets you write for-loop-style code operating on NumPy arrays, with the code compiled at runtime.
Another alternative is to use Cython, Rust, Fortran or other compiled languages to write a compiled version of the algorithm.
As with Numba, using a compiler gets you automatic optimizations, and the ability to use for loops can improve memory efficiency.
Unlike Numba, the code needs to be compiled in advance. As a result, by default compiled extensions will target generic CPUs, not your CPU specifically. You can make the compiler target your CPU, but then sharing the compiled code with other computers with different CPUs can be difficult.
Comparing the options
How does each of these options solve the three bottlenecks we considered?
| Technique provides: | Cross-operation optimization | CPU-specific code | Memory efficiency |
|---|---|---|---|
| Manual optimization | Sometimes | Sometimes | Sometimes |
| JAX | Yes (within function) | Yes | Unclear |
| Numba | Yes (within function) | Yes | Yes |
| Cython/Rust/Fortran/C | Yes (within function) | With extra work | Yes |
Step #3: Consider parallelism
As I discuss in more detail elsewhere, you should focus on optimizing single-core performance before you consider parallelism. Faster algorithms can give massive speedups, and beyond that it’s quite possible to get an additional 3×, 10×, or even 25× speedup (that’s 2400% faster!) by using tools like Numba to make your code more efficient.
Once you’ve hit the limits of single-core speedups, you can consider parallelism. Many of the tools mentioned here have the ability to use multiple cores: Numba if you opt in, Rust if you use Rayon, and JAX’s JIT may do so automatically. But if you jump straight to parallelism you may be ignoring an order of magnitude (or more!) of potential speedups.
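For example, here’s a hedged sketch of opting in to multiple cores with Numba; numba.prange parallelizes the loop, and Numba treats the += on total as a reduction:

```python
from numba import njit, prange
import numpy as np

@njit(parallel=True)
def myvar_parallel(arr):
    mean = arr.mean()
    total = 0.0
    # prange splits iterations across cores; Numba combines the
    # per-thread partial sums of `total` at the end.
    for i in prange(arr.size):
        total += (arr[i] - mean) ** 2
    return total / arr.size
```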