Let’s build and optimize a Rust extension for Python

If your Python code isn’t fast enough, you have many compiled languages to choose from for writing a faster extension. In this article we’ll focus on Rust, which benefits from:

  • Modern tooling, including a package repository (crates.io) and a built-in build tool and package manager (cargo).
  • Excellent Python integration and tooling. The Rust package (Rust packages are known as “crates”) for Python support is PyO3. For packaging you can use setuptools-rust to integrate with existing setuptools projects, or Maturin for standalone extensions.
  • Memory- and thread-safety, which makes Rust much less prone to crashes or memory corruption than C and C++.

In particular, we’ll:

  • Implement a small algorithm in Python.
  • Re-implement it as a Rust extension.
  • Optimize the Rust version so it runs faster.

Counting unique values, the approximate way

As a motivating example, we’re going to look at the task of counting how many unique values there are in a list. If we want to get an exact answer, this is very easy to implement:

def count_exact(items):
    return len(set(items))

A set() can only contain an item once, so this de-duplicates all items and then counts them.

The problem with this solution is memory usage. If we have 10,000,000 unique values, we’ll create a set with 10,000,000 entries in it, which will use quite a lot of memory.

So if we’re worried about memory usage, we can use a probabilistic algorithm that gives us an approximate answer. In many situations this will be good enough, and it can use a lot less memory. We’ll be using a very simple algorithm by Chakraborty, Vinodchandran, and Meel:

import random
import math

# Implementation of https://arxiv.org/abs/2301.10191
def count_approx_python(items, epsilon=0.5, delta=0.001):
    # Will be used to scale tracked_items upwards:
    p = 1
    # Items we're currently tracking, limited in length:
    tracked_items = set()
    # Max length of tracked_items:
    max_tracked = round(
        (12 / (epsilon ** 2)) *
        math.log2(8 * len(items) / delta)
    )
    for item in items:
        tracked_items.discard(item)
        if random.random() < p:
            tracked_items.add(item)
        if len(tracked_items) == max_tracked:
            # Drop tracked values with 50% probability.
            # Every item in tracked_items now implies the
            # final length is twice as large.
            tracked_items = {
                item for item in tracked_items
                if random.random() < 0.5
            }
            p /= 2
            if len(tracked_items) == 0:
                raise RuntimeError(
                    "we got unlucky, no answer"
                )
    return int(round(len(tracked_items) / p))

Running an example

Let’s look at an example set of words:

WORDS = [str(i) for i in range(100_000)] * 100
random.shuffle(WORDS)

We have 100,000 distinct words, each repeated 100 times, for 10,000,000 items in total. That means count_exact() will create a set of size 100,000. Meanwhile, count_approx_python()’s set is capped at 1739 entries, the max_tracked value for this input with the default epsilon and delta. It occasionally holds two sets at once (while dropping tracked values), but it will still use only about 3% as much memory as the exact algorithm.

We can compare the results:

print("EXACT", count_exact(WORDS))
print("APPROX", count_approx_python(WORDS))

I ran this three times in a row, and got:

EXACT 100000
APPROX 99712

EXACT 100000
APPROX 99072

EXACT 100000
APPROX 100864

The approximate version varies, but it’s pretty close—while using much less memory.

A speed comparison

We can compare the speed of the two implementations:

import time

def timeit(f, *args, **kwargs):
    start = time.time()
    for _ in range(10):
        f(*args, **kwargs)
    print(f.__name__, (time.time() - start) / 10, "seconds")

timeit(count_exact, WORDS)
timeit(count_approx_python, WORDS)

This gives us:

count_exact 0.14 seconds
count_approx_python 0.78 seconds

Our new function uses less memory, but it’s 5× slower. Let’s see if we can speed up count_approx_python() by rewriting it in Rust.

Going faster: Creating our Rust project

Using the Maturin Python packaging tool, we can create a new Rust Python extension very quickly, and using PyO3 we can easily interact with Python objects.

1. Initializing the project with Maturin

You can install Maturin with pipx or pip, and then use it to initialize a whole new PyO3-based project:

$ maturin new rust_count_approx
✔ 🤷 Which kind of bindings to use?
  📖 Documentation: https://maturin.rs/bindings.html · pyo3
  ✨ Done! New project created rust_count_approx
$ cd rust_count_approx/

This creates all the files we need for a basic Rust-based Python package:

$ tree
├── Cargo.lock
├── Cargo.toml
├── pyproject.toml
└── src
    └── lib.rs

We can pip install the package right from the start, no further setup needed:

$ pip install .
...
Successfully installed rust_count_approx-0.1.0
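
The generated src/lib.rs contains a small placeholder function and module definition, roughly like this (the exact contents vary by Maturin version); we’ll replace it with our own code shortly:

use pyo3::prelude::*;

/// Formats the sum of two numbers as a string.
#[pyfunction]
fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
    Ok((a + b).to_string())
}

/// A Python module implemented in Rust.
#[pymodule]
fn rust_count_approx(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
    Ok(())
}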

2. Adding dependencies

Rust doesn’t include a built-in random number generation library, so we will add a third-party crate (Rust jargon for an installable package) using Rust’s build/package manager cargo:

$ cargo add rand
    Updating crates.io index
      Adding rand v0.8.5 to dependencies
             Features:
             + alloc
             + getrandom
             + libc
             + rand_chacha
             + std
             + std_rng
             - log
             - min_const_gen
             - nightly
             - packed_simd
             - serde
             - serde1
             - simd_support
             - small_rng

This updates the Cargo.toml file, which among other things lists the Rust dependencies required to build our code. The relevant section now looks like this:

[dependencies]
pyo3 = "0.22.0"
rand = "0.8.5"

The pyo3 dependency was added automatically by Maturin when we initialized the project template.

The features listed in the command output are flags that can be enabled at compile time to add more functionality; we’ll be using one of them later on.
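
As a rough sketch of the mechanism (with a made-up feature name, not rand’s actual code), a crate gates optional code behind a feature flag like this:

// This function is only compiled when the hypothetical "extra_logging"
// feature is enabled in the user's Cargo.toml:
#[cfg(feature = "extra_logging")]
pub fn log_details(message: &str) {
    println!("detail: {message}");
}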

3. A first Rust version

Now, we need to update the Rust code in src/lib.rs to implement our function:

use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::{PySequence, PySet};
use rand::random;

#[pyfunction]
#[pyo3(signature = (items, epsilon=0.5, delta=0.001))]
fn count_approx_rust(
    py: Python<'_>,
    items: &Bound<PySequence>,
    epsilon: f64,
    delta: f64,
) -> PyResult<u64> {
    let mut p = 1.0;
    let mut tracked_items = PySet::empty_bound(py)?;
    let max_tracked =
        ((12.0 / epsilon.powi(2)) *
         (8.0 * (items.len()? as f64) / delta).log2()
        ).round() as usize;
    // In future versions of PyO3 iter() will be
    // renamed to try_iter():
    for item in items.iter()? {
        let item = item?;
        tracked_items.discard(item.clone())?;
        if random::<f64>() < p {
            tracked_items.add(item)?;
        }
        if tracked_items.len() == max_tracked {
            let temp_tracked_items =
                PySet::empty_bound(py)?;
            for subitem in tracked_items.iter() {
                if random::<f64>() < 0.5 {
                    temp_tracked_items.add(subitem)?;
                }
            }
            tracked_items = temp_tracked_items;
            p /= 2.0;
            if tracked_items.len() == 0 {
                return Err(PyRuntimeError::new_err(
                    "we got unlucky, no answer"
                ));
            }
        }
    }
    Ok((tracked_items.len() as f64 / p).round() as u64)
}

// Expose the function above via an importable Python
// extension.
#[pymodule]
fn rust_count_approx(
    m: &Bound<'_, PyModule>
) -> PyResult<()> {
    m.add_function(
        wrap_pyfunction!(count_approx_rust, m)?
    )?;
    Ok(())
}

4. Measuring performance

How does our new version compare as far as speed goes?

First, we install our new package:

$ pip install .

Now we can import our new function from Python code and measure its speed:

from rust_count_approx import count_approx_rust

# See above for definition of WORDS and timeit():
timeit(count_approx_rust, WORDS)

Here’s how the new version compares:

Version           Elapsed seconds
Python            0.78
Rust (naive)      0.37

It’s twice as fast. Why isn’t it faster?

This is logically the exact same code as the Python implementation, just implemented in Rust and a little more verbosely. We’re still interacting with Python objects in the same way, iterating over a Python list, and extensively interacting with a Python set. So that part of the code isn’t going to run any differently.

Let’s go faster, part 2: Optimizing the Rust code

Next we’re going to optimize our code in four different ways, all of which I measured separately as improving the performance.

First, we’re going to enable link-time optimization, where the Rust compiler optimizes the code much later in the compilation process, during linking. This means slower compilation, but typically results in faster-running code. We do this by adding the following to Cargo.toml:

[profile.release]
lto = true

Optimizations 2 and 3: Faster random number generation

We also switch from using rand::random(), which uses a thread-local random number generator (RNG), to an RNG we manage ourselves, removing the overhead of thread-local lookups.

At the same time, we switch to a faster RNG than the default: the “small” RNG. It isn’t cryptographically secure, but for our purposes that’s probably fine. To do this we add the small_rng feature to the rand crate in Cargo.toml:

[dependencies]
pyo3 = "0.22.0"
rand = {version = "0.8.5", features = ["small_rng"]}

And we’ll need to modify the code, which we’ll see below.
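
In the meantime, here’s a minimal standalone sketch of just the RNG change, assuming the small_rng feature is enabled:

use rand::{rngs::SmallRng, Rng, SeedableRng};

fn main() {
    // Instead of rand::random::<f64>(), which looks up a thread-local
    // RNG on every call, we create one SmallRng up front and reuse it:
    let mut rng = SmallRng::from_entropy();
    let value: f64 = rng.gen(); // uniform in [0.0, 1.0)
    assert!((0.0..1.0).contains(&value));
}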

Optimization 4: Store hashes only

Finally, we switch from storing the Python objects in a Python-based set to storing just the hash of the Python object, as calculated by obj.__hash__(). What happens if two objects hash to the same value? The result will be off by one.

We’re already using a probabilistic algorithm, so we’re already OK with slightly wrong results. Hash clashes should be rare, and an answer that’s very occasionally off by one doesn’t matter when we have many unique values; 99712 isn’t particularly more wrong than 99713.

Since we’re no longer storing Python objects, we can switch to using Rust’s std::collections::HashSet, which has some nicer APIs and may be a bit faster. However, the Rust HashSet will want to hash the values… and we don’t want to hash them again, they’re already pre-hashed.

To avoid double-hashing, we also add the Rust crate nohash-hasher:

$ cargo add nohash-hasher

This will allow us to use a Rust-based HashSet that assumes the values it stores are already hashes and can be used in the HashSet as is, without further hashing. The Cargo.toml dependencies section now looks like this:

[dependencies]
nohash-hasher = "0.2.0"
pyo3 = "0.22.0"
rand = {version = "0.8.5", features = ["small_rng"]}
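
As a standalone illustration of the idea (the hash value here is made up), an IntSet behaves like an ordinary set, but its “hasher” passes the integer key through unchanged:

use nohash_hasher::IntSet;

fn main() {
    // IntSet<isize> is std's HashSet with a pass-through hasher, so
    // values that are already hashes aren't hashed a second time.
    let mut seen: IntSet<isize> = IntSet::default();
    seen.insert(12345); // pretend this is a Python __hash__ result
    seen.insert(12345); // duplicates still collapse to a single entry
    assert_eq!(seen.len(), 1);
}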

We will also need to update our code to use this, as we’ll see next.

Our new code, with all 4 optimizations

Here’s our updated code:

use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PySequence;
use rand::{rngs::SmallRng, SeedableRng, Rng};
use nohash_hasher::IntSet;

#[pyfunction]
#[pyo3(signature = (items, epsilon=0.5, delta=0.001))]
fn count_approx_rust(
    py: Python<'_>,
    items: &Bound<PySequence>,
    epsilon: f64,
    delta: f64,
) -> PyResult<u64> {
    let mut p = 1.0;
    // Use a set that stores integer values that are assumed
    // to be hashes themselves:
    let mut tracked_items = IntSet::default();
    // Use an RNG we manage ourselves, specifically SmallRng
    // which is faster than the default rng that the rand
    // crate uses:
    let mut rng = SmallRng::from_entropy();
    // Create a closure, similar to a Python lambda:
    let mut random = || rng.gen::<f64>();
    let max_tracked =
        ((12.0 / epsilon.powi(2)) *
         (8.0 * (items.len()? as f64) / delta).log2()
        ).round() as usize;
    for item in items.iter()? {
        // Instead of storing the item, we store its
        // Python-calculated hash (the output of __hash__):
        let hash = item?.hash()?;
        tracked_items.remove(&hash);
        if random() < p {
            tracked_items.insert(hash);
        }
        if tracked_items.len() == max_tracked {
            tracked_items.retain(|_| random() < 0.5);
            p /= 2.0;
            if tracked_items.len() == 0 {
                return Err(PyRuntimeError::new_err(
                    "we got unlucky, no answer"
                ));
            }
        }
    }
    Ok((tracked_items.len() as f64 / p).round() as u64)
}

Again, we can install our updated version by doing:

$ pip install .

Here’s how long our optimized version takes to run in comparison to previous versions:

Version           Elapsed seconds
Python            0.78
Rust (naive)      0.37
Rust (optimized)  0.21

Why this isn’t faster, and additional ideas

Why isn’t the Rust code even faster? Our latest optimized version still interacts with a Python list, namely items, and uses Python’s __hash__ API to hash every object. That means we’re still limited by the speed of those two Python interactions. If we passed in items as an Arrow column, or a NumPy array of integers, it could probably run much faster.
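
For example, here’s a sketch (not from this project) of what the integer-array variant might look like, using the third-party numpy crate (rust-numpy) with a version matched to our PyO3 version. To keep it short, the body just does an exact count, but the approximate loop would work on the raw i64 values the same way:

use std::collections::HashSet;

use numpy::PyReadonlyArray1;
use pyo3::prelude::*;

// Hypothetical sketch: take a NumPy i64 array instead of a Python
// sequence, so the loop touches plain integers, not Python objects.
#[pyfunction]
fn count_exact_ints(items: PyReadonlyArray1<i64>) -> PyResult<u64> {
    let unique: HashSet<i64> =
        items.as_array().iter().copied().collect();
    Ok(unique.len() as u64)
}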

More broadly, if we wanted to go even faster, we should also consider using a different algorithm. There are many approximate counting algorithms, and I only picked this one because it was simple and easy to understand, not because it was necessarily particularly fast.

The big picture: Why Rust?

Rust lets us speed up our code by giving us access to a compiled language. But that’s also true of C, C++, and Cython.

Unlike those languages, however, Rust has modern tooling, with a built-in package manager and build system:

  • Adding a new dependency is as easy as cargo add <cratename>. A good site to browse available crates is lib.rs.
  • Less visible, but still important: this code will compile on macOS and Windows with no additional work. In contrast, managing C or C++ dependencies across platforms can be very painful.

Rust also has excellent Python integration, with high-level access to Python APIs and easy-to-use packaging tools like Maturin.

Rust scales both to small projects—as in this article—and to much larger projects, thanks to memory- and thread-safety and to its powerful type system. The Polars project has a generic Rust core of 330K lines, which is then wrapped with more Rust and Python to make a Python extension.

Next time you’re considering speeding up some Python code, give Rust a try. It’s not a simple language, but it’s well worth your time to learn.