4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

はじめに

Python は書きやすいけど遅い。
Rust は速いけど書くのに時間がかかる。

じゃあ、ホットパスだけ Rust で書けばいいじゃん。

PyO3 を使えば、Rust で書いた関数を Python から呼べます。

目次

  1. PyO3 とは
  2. セットアップ
  3. Hello World
  4. 実践:高速化
  5. データの受け渡し
  6. クラスの定義
  7. まとめ

PyO3 とは

Rust と Python のバインディングを作成するためのクレート。

  • Rust の関数を Python から呼べる
  • Python のオブジェクトを Rust で扱える
  • GIL を制御できる
┌─────────────┐       ┌─────────────┐
│   Python    │◀─────▶│    Rust     │
│             │ PyO3  │             │
│  (遅い処理)  │       │  (速い処理)  │
└─────────────┘       └─────────────┘

セットアップ

maturin のインストール

PyO3 のビルドには maturin を使います。

pip install maturin

プロジェクト作成

mkdir rust-python-demo
cd rust-python-demo
maturin init

pyo3 を選択。

生成されるファイル:

rust-python-demo/
├── Cargo.toml
├── pyproject.toml
└── src/
    └── lib.rs

Cargo.toml

[package]
name = "rust_python_demo"
version = "0.1.0"
edition = "2021"

[lib]
name = "rust_python_demo"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }

Hello World

src/lib.rs

use pyo3::prelude::*;

#[pyfunction]
fn hello() -> PyResult<String> {
    Ok("Hello from Rust!".to_string())
}

#[pyfunction]
fn add(a: i64, b: i64) -> PyResult<i64> {
    Ok(a + b)
}

#[pymodule]
fn rust_python_demo(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(hello, m)?)?;
    m.add_function(wrap_pyfunction!(add, m)?)?;
    Ok(())
}

ビルド

maturin develop

Python から呼び出し

import rust_python_demo

print(rust_python_demo.hello())  # Hello from Rust!
print(rust_python_demo.add(2, 3))  # 5

簡単!

実践:高速化

フィボナッチ数列

Python 版(遅い)

def fib_py(n: int) -> int:
    if n < 2:
        return n
    return fib_py(n - 1) + fib_py(n - 2)

Rust 版(速い)

use pyo3::prelude::*;

#[pyfunction]
fn fib_rust(n: u64) -> PyResult<u64> {
    fn fib(n: u64) -> u64 {
        if n < 2 {
            n
        } else {
            fib(n - 1) + fib(n - 2)
        }
    }
    Ok(fib(n))
}

#[pymodule]
fn rust_python_demo(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(fib_rust, m)?)?;
    Ok(())
}

ベンチマーク

import time
import rust_python_demo

def fib_py(n):
    if n < 2:
        return n
    return fib_py(n - 1) + fib_py(n - 2)

n = 35

# Python
start = time.time()
result_py = fib_py(n)
time_py = time.time() - start
print(f"Python: {result_py} ({time_py:.3f}s)")

# Rust
start = time.time()
result_rust = rust_python_demo.fib_rust(n)
time_rust = time.time() - start
print(f"Rust: {result_rust} ({time_rust:.3f}s)")

print(f"Speedup: {time_py / time_rust:.1f}x")

結果:

Python: 9227465 (2.847s)
Rust: 9227465 (0.034s)
Speedup: 83.7x

83倍速い!

素数判定

#[pyfunction]
fn count_primes(n: u64) -> PyResult<u64> {
    fn is_prime(num: u64) -> bool {
        if num < 2 {
            return false;
        }
        if num == 2 {
            return true;
        }
        if num % 2 == 0 {
            return false;
        }
        let sqrt = (num as f64).sqrt() as u64;
        for i in (3..=sqrt).step_by(2) {
            if num % i == 0 {
                return false;
            }
        }
        true
    }
    
    Ok((2..=n).filter(|&x| is_prime(x)).count() as u64)
}
import rust_python_demo

def count_primes_py(n):
    def is_prime(num):
        if num < 2:
            return False
        if num == 2:
            return True
        if num % 2 == 0:
            return False
        for i in range(3, int(num**0.5) + 1, 2):
            if num % i == 0:
                return False
        return True
    return sum(1 for x in range(2, n + 1) if is_prime(x))

n = 100000

# Python: 約 2.5 秒
# Rust: 約 0.05 秒
# Speedup: 50x

データの受け渡し

リスト(Vec)

use pyo3::prelude::*;

#[pyfunction]
fn sum_list(numbers: Vec<i64>) -> PyResult<i64> {
    Ok(numbers.iter().sum())
}

#[pyfunction]
fn double_list(numbers: Vec<i64>) -> PyResult<Vec<i64>> {
    Ok(numbers.iter().map(|x| x * 2).collect())
}
import rust_python_demo

print(rust_python_demo.sum_list([1, 2, 3, 4, 5]))  # 15
print(rust_python_demo.double_list([1, 2, 3]))  # [2, 4, 6]

NumPy 配列

[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }
numpy = "0.22"
use numpy::ndarray::ArrayD;
use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn};
use pyo3::prelude::*;

#[pyfunction]
fn numpy_sum<'py>(
    py: Python<'py>,
    array: PyReadonlyArrayDyn<'py, f64>,
) -> PyResult<f64> {
    let array = array.as_array();
    Ok(array.sum())
}

#[pyfunction]
fn numpy_double<'py>(
    py: Python<'py>,
    array: PyReadonlyArrayDyn<'py, f64>,
) -> Bound<'py, PyArrayDyn<f64>> {
    let array = array.as_array();
    let result: ArrayD<f64> = array.mapv(|x| x * 2.0);
    result.into_pyarray(py)
}
import numpy as np
import rust_python_demo

arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(rust_python_demo.numpy_sum(arr))  # 15.0
print(rust_python_demo.numpy_double(arr))  # [2. 4. 6. 8. 10.]

辞書

use pyo3::prelude::*;
use pyo3::types::PyDict;
use std::collections::HashMap;

#[pyfunction]
fn process_dict(data: HashMap<String, i64>) -> PyResult<i64> {
    Ok(data.values().sum())
}

#[pyfunction]
fn create_dict(py: Python) -> PyResult<Py<PyDict>> {
    let dict = PyDict::new(py);
    dict.set_item("a", 1)?;
    dict.set_item("b", 2)?;
    dict.set_item("c", 3)?;
    Ok(dict.into())
}
print(rust_python_demo.process_dict({"x": 10, "y": 20}))  # 30
print(rust_python_demo.create_dict())  # {'a': 1, 'b': 2, 'c': 3}

クラスの定義

基本的なクラス

use pyo3::prelude::*;

#[pyclass]
struct Counter {
    value: i64,
}

#[pymethods]
impl Counter {
    #[new]
    fn new(initial: i64) -> Self {
        Counter { value: initial }
    }
    
    fn increment(&mut self) {
        self.value += 1;
    }
    
    fn get(&self) -> i64 {
        self.value
    }
    
    fn __repr__(&self) -> String {
        format!("Counter({})", self.value)
    }
}

#[pymodule]
fn rust_python_demo(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<Counter>()?;
    Ok(())
}
from rust_python_demo import Counter

c = Counter(0)
c.increment()
c.increment()
print(c.get())  # 2
print(c)  # Counter(2)

プロパティ

#[pyclass]
struct Point {
    #[pyo3(get, set)]
    x: f64,
    #[pyo3(get, set)]
    y: f64,
}

#[pymethods]
impl Point {
    #[new]
    fn new(x: f64, y: f64) -> Self {
        Point { x, y }
    }
    
    fn distance(&self, other: &Point) -> f64 {
        ((self.x - other.x).powi(2) + (self.y - other.y).powi(2)).sqrt()
    }
}
from rust_python_demo import Point

p1 = Point(0.0, 0.0)
p2 = Point(3.0, 4.0)
print(p1.x, p1.y)  # 0.0 0.0
print(p1.distance(p2))  # 5.0

GIL の解放

CPU バウンドな処理で GIL を解放すると、Python 側で他のスレッドが動ける。

use pyo3::prelude::*;

#[pyfunction]
fn heavy_computation(py: Python, n: u64) -> PyResult<u64> {
    // GIL を解放して計算
    py.allow_threads(|| {
        let mut sum = 0u64;
        for i in 0..n {
            sum = sum.wrapping_add(i);
        }
        Ok(sum)
    })
}

エラーハンドリング

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

#[pyfunction]
fn divide(a: f64, b: f64) -> PyResult<f64> {
    if b == 0.0 {
        Err(PyValueError::new_err("Division by zero"))
    } else {
        Ok(a / b)
    }
}
import rust_python_demo

try:
    result = rust_python_demo.divide(10.0, 0.0)
except ValueError as e:
    print(f"Error: {e}")  # Error: Division by zero

まとめ

PyO3 の使いどころ

  1. CPU バウンドな処理 - 計算が重い部分
  2. ループが多い処理 - Python のループは遅い
  3. 数値計算 - NumPy でも足りない場合
  4. 既存の Rust ライブラリを使いたい

パフォーマンス比較

処理 Python Rust (PyO3) 高速化
フィボナッチ(35) 2.85s 0.03s 83x
素数カウント(100k) 2.5s 0.05s 50x
配列の合計(1M) 0.1s 0.002s 50x

チェックリスト

  • maturin をインストール
  • #[pyfunction] で関数を公開
  • #[pyclass] でクラスを公開
  • #[pymodule] でモジュールを定義
  • maturin develop でビルド

参考リンク

Python の使いやすさと Rust の速さ、両方のいいとこ取りができます。遅い処理があったら、PyO3 で高速化を検討してみてください!

この記事が役に立ったら、いいね・ストックしてもらえると嬉しいです!

4
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?