はじめに
Python は書きやすいけど遅い。
Rust は速いけど書くのに時間がかかる。
じゃあ、ホットパスだけ Rust で書けばいいじゃん。
PyO3 を使えば、Rust で書いた関数を Python から呼べます。
目次
PyO3 とは
Rust と Python のバインディングを作成するためのクレート。
- Rust の関数を Python から呼べる
- Python のオブジェクトを Rust で扱える
- GIL を制御できる
┌─────────────┐ ┌─────────────┐
│ Python │◀─────▶│ Rust │
│ │ PyO3 │ │
│ (遅い処理) │ │ (速い処理) │
└─────────────┘ └─────────────┘
セットアップ
maturin のインストール
PyO3 のビルドには maturin を使います。
pip install maturin
プロジェクト作成
mkdir rust-python-demo
cd rust-python-demo
maturin init
pyo3 を選択。
生成されるファイル:
rust-python-demo/
├── Cargo.toml
├── pyproject.toml
└── src/
└── lib.rs
Cargo.toml
[package]
name = "rust_python_demo"
version = "0.1.0"
edition = "2021"
[lib]
name = "rust_python_demo"
crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }
Hello World
src/lib.rs
use pyo3::prelude::*;
#[pyfunction]
fn hello() -> PyResult<String> {
Ok("Hello from Rust!".to_string())
}
#[pyfunction]
fn add(a: i64, b: i64) -> PyResult<i64> {
Ok(a + b)
}
#[pymodule]
fn rust_python_demo(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(hello, m)?)?;
m.add_function(wrap_pyfunction!(add, m)?)?;
Ok(())
}
ビルド
maturin develop
Python から呼び出し
import rust_python_demo
print(rust_python_demo.hello()) # Hello from Rust!
print(rust_python_demo.add(2, 3)) # 5
簡単!
実践:高速化
フィボナッチ数列
Python 版(遅い)
def fib_py(n: int) -> int:
if n < 2:
return n
return fib_py(n - 1) + fib_py(n - 2)
Rust 版(速い)
use pyo3::prelude::*;
#[pyfunction]
fn fib_rust(n: u64) -> PyResult<u64> {
fn fib(n: u64) -> u64 {
if n < 2 {
n
} else {
fib(n - 1) + fib(n - 2)
}
}
Ok(fib(n))
}
#[pymodule]
fn rust_python_demo(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(fib_rust, m)?)?;
Ok(())
}
ベンチマーク
import time
import rust_python_demo
def fib_py(n):
if n < 2:
return n
return fib_py(n - 1) + fib_py(n - 2)
n = 35
# Python
start = time.time()
result_py = fib_py(n)
time_py = time.time() - start
print(f"Python: {result_py} ({time_py:.3f}s)")
# Rust
start = time.time()
result_rust = rust_python_demo.fib_rust(n)
time_rust = time.time() - start
print(f"Rust: {result_rust} ({time_rust:.3f}s)")
print(f"Speedup: {time_py / time_rust:.1f}x")
結果:
Python: 9227465 (2.847s)
Rust: 9227465 (0.034s)
Speedup: 83.7x
83倍速い!
素数判定
#[pyfunction]
fn count_primes(n: u64) -> PyResult<u64> {
fn is_prime(num: u64) -> bool {
if num < 2 {
return false;
}
if num == 2 {
return true;
}
if num % 2 == 0 {
return false;
}
let sqrt = (num as f64).sqrt() as u64;
for i in (3..=sqrt).step_by(2) {
if num % i == 0 {
return false;
}
}
true
}
Ok((2..=n).filter(|&x| is_prime(x)).count() as u64)
}
import rust_python_demo
def count_primes_py(n):
def is_prime(num):
if num < 2:
return False
if num == 2:
return True
if num % 2 == 0:
return False
for i in range(3, int(num**0.5) + 1, 2):
if num % i == 0:
return False
return True
return sum(1 for x in range(2, n + 1) if is_prime(x))
n = 100000
# Python: 約 2.5 秒
# Rust: 約 0.05 秒
# Speedup: 50x
データの受け渡し
リスト(Vec)
use pyo3::prelude::*;
#[pyfunction]
fn sum_list(numbers: Vec<i64>) -> PyResult<i64> {
Ok(numbers.iter().sum())
}
#[pyfunction]
fn double_list(numbers: Vec<i64>) -> PyResult<Vec<i64>> {
Ok(numbers.iter().map(|x| x * 2).collect())
}
import rust_python_demo
print(rust_python_demo.sum_list([1, 2, 3, 4, 5])) # 15
print(rust_python_demo.double_list([1, 2, 3])) # [2, 4, 6]
NumPy 配列
[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }
numpy = "0.22"
use numpy::ndarray::ArrayD;
use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn};
use pyo3::prelude::*;
#[pyfunction]
fn numpy_sum<'py>(
py: Python<'py>,
array: PyReadonlyArrayDyn<'py, f64>,
) -> PyResult<f64> {
let array = array.as_array();
Ok(array.sum())
}
#[pyfunction]
fn numpy_double<'py>(
py: Python<'py>,
array: PyReadonlyArrayDyn<'py, f64>,
) -> Bound<'py, PyArrayDyn<f64>> {
let array = array.as_array();
let result: ArrayD<f64> = array.mapv(|x| x * 2.0);
result.into_pyarray(py)
}
import numpy as np
import rust_python_demo
arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(rust_python_demo.numpy_sum(arr)) # 15.0
print(rust_python_demo.numpy_double(arr)) # [2. 4. 6. 8. 10.]
辞書
use pyo3::prelude::*;
use pyo3::types::PyDict;
use std::collections::HashMap;
#[pyfunction]
fn process_dict(data: HashMap<String, i64>) -> PyResult<i64> {
Ok(data.values().sum())
}
#[pyfunction]
fn create_dict(py: Python) -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
dict.set_item("a", 1)?;
dict.set_item("b", 2)?;
dict.set_item("c", 3)?;
Ok(dict.into())
}
print(rust_python_demo.process_dict({"x": 10, "y": 20})) # 30
print(rust_python_demo.create_dict()) # {'a': 1, 'b': 2, 'c': 3}
クラスの定義
基本的なクラス
use pyo3::prelude::*;
#[pyclass]
struct Counter {
value: i64,
}
#[pymethods]
impl Counter {
#[new]
fn new(initial: i64) -> Self {
Counter { value: initial }
}
fn increment(&mut self) {
self.value += 1;
}
fn get(&self) -> i64 {
self.value
}
fn __repr__(&self) -> String {
format!("Counter({})", self.value)
}
}
#[pymodule]
fn rust_python_demo(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Counter>()?;
Ok(())
}
from rust_python_demo import Counter
c = Counter(0)
c.increment()
c.increment()
print(c.get()) # 2
print(c) # Counter(2)
プロパティ
#[pyclass]
struct Point {
#[pyo3(get, set)]
x: f64,
#[pyo3(get, set)]
y: f64,
}
#[pymethods]
impl Point {
#[new]
fn new(x: f64, y: f64) -> Self {
Point { x, y }
}
fn distance(&self, other: &Point) -> f64 {
((self.x - other.x).powi(2) + (self.y - other.y).powi(2)).sqrt()
}
}
from rust_python_demo import Point
p1 = Point(0.0, 0.0)
p2 = Point(3.0, 4.0)
print(p1.x, p1.y) # 0.0 0.0
print(p1.distance(p2)) # 5.0
GIL の解放
CPU バウンドな処理で GIL を解放すると、Python 側で他のスレッドが動ける。
use pyo3::prelude::*;
#[pyfunction]
fn heavy_computation(py: Python, n: u64) -> PyResult<u64> {
// GIL を解放して計算
py.allow_threads(|| {
let mut sum = 0u64;
for i in 0..n {
sum = sum.wrapping_add(i);
}
Ok(sum)
})
}
エラーハンドリング
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
#[pyfunction]
fn divide(a: f64, b: f64) -> PyResult<f64> {
if b == 0.0 {
Err(PyValueError::new_err("Division by zero"))
} else {
Ok(a / b)
}
}
import rust_python_demo
try:
result = rust_python_demo.divide(10.0, 0.0)
except ValueError as e:
print(f"Error: {e}") # Error: Division by zero
まとめ
PyO3 の使いどころ
- CPU バウンドな処理 - 計算が重い部分
- ループが多い処理 - Python のループは遅い
- 数値計算 - NumPy でも足りない場合
- 既存の Rust ライブラリを使いたい
パフォーマンス比較
| 処理 | Python | Rust (PyO3) | 高速化 |
|---|---|---|---|
| フィボナッチ(35) | 2.85s | 0.03s | 83x |
| 素数カウント(100k) | 2.5s | 0.05s | 50x |
| 配列の合計(1M) | 0.1s | 0.002s | 50x |
チェックリスト
-
maturinをインストール -
#[pyfunction]で関数を公開 -
#[pyclass]でクラスを公開 -
#[pymodule]でモジュールを定義 -
maturin developでビルド
参考リンク
Python の使いやすさと Rust の速さ、両方のいいとこ取りができます。遅い処理があったら、PyO3 で高速化を検討してみてください!
この記事が役に立ったら、いいね・ストックしてもらえると嬉しいです!