こんにちは。
この記事は、Rustその2 Advent Calendar 2019の24日目の記事として書かれました。
もう25日ですが...。
はじめに
Rust製ライブラリのPythonバインディングを簡単に作成するためのAPIを提供するライブラリ、PyO3を紹介します。前半はチュートリアル的にAPIを概観します。後半はバグの具体例を見ながら「どういう問題があるのか」について簡単に説明します。
時勢のあいさつとか
2018 editionが出てすぐに迎えた2019年ですが、Rustにとって激動の年であったように思います。async/awaitが登場し言語が進化していく一方で、nrc氏・aturon氏がチームを離れたのも印象的な出来事でした。
一方で僕が何をしていたのかというと、なんということでしょうか!だいたいPythonを書いていたわけで、すっかり時代遅れの人になってしまいました。
ただ申し訳程度にRustプロジェクトのメンテナンスも続けており、そのうちの一つがPyO3です。
PythonとC拡張
Pythonというスクリプト言語が、どういうわけか今日では非常に人気があるようです。
知らない方のため雑な説明を試みると、
class MyClass(object):
def __init__(self, name):
self.name = name
def __repr__(self):
return "MyClass()".format(self.name)
print(MyClass("Do I have a name? Really?"))
こんな感じでしょうか。
全てのクラスはobject
を継承する派生クラスです。Rustと異なり、インスタンスは全てガベージコレクションにより管理されます。Rustのtraitのように、継承と関係なく特定のメソッド群を持たせたい場合は、dunderと呼ばれる変な名前のメソッド群を使います。上の例だと、__repr__
がそれですね。
Pythonは残念ながらあまり速くありません。にも関わらず、今日では、従来matlab/Rが担ってきた数値計算・統計処理の分野で非常によく使われているようです。
その理由としてしばしば挙げられるのが、numpyを中心とする高速で大規模(=やたらAPIが多い)な数値計算用ライブラリ群の存在です。numpyはPythonの主要な処理系であるCPythonが提供するC-API1を用いて主にC言語で書かれています。PyTorchなど最近流行のGPU行列演算・深層学習用のライブラリも多くが、pybind11などのライブラリにより、C++で書かれています。
このように、「C/C++製の高速なライブラリにPythonバインディングをつける/最初からPython用に設計する」ということが、機械学習などPythonが支配的な分野では、広く行われています。
すると、C++を葬るために生まれた鬼子たるRust言語を信奉する者としては、「この用途でもC++に勝ちたい(過激派)」と思うわけです。
使い方の例
そこでPyO3の出番です。これを使うととても簡単にPython拡張が作れます。どのくらい簡単かと言うと、正直僕もユーザーとしてはしばらく使っていなかったしよくわからないので、試しにPython拡張を作ってみました。せっかくなので研究で使えるものがいいかなと、Context tree switchingというマルコフ連鎖の次を予測するアルゴリズムをやっつけで実装しました。
これに、以下のようなバインディグをつけました。
use pyo3::prelude::*;
mod cts;
fn convert_result<T>(res: Result<T, cts::CtsError>) -> PyResult<T> {
res.map_err(|e| PyErr::new::<pyo3::exceptions::RuntimeError, _>(format!("{}", e)))
}
#[pyclass]
pub struct SequencePredictor {
context: Vec<u8>,
cts: cts::Cts<u8>,
}
#[pymethods]
impl SequencePredictor {
#[new]
fn new(context_length: usize, symbol_size: u32, initial_symbol: u8) -> PyResult<Self> {
let mut context = Vec::new();
for _ in 0..context_length {
context.push(initial_symbol);
}
convert_result(
cts::Cts::new(context_length, symbol_size, "perks")
.map(|cts| SequencePredictor { context, cts: cts }),
)
}
fn train(&mut self, text: &[u8]) -> PyResult<f64> {
let mut log_prob = 0.0;
for (i, &ch) in text.iter().enumerate() {
let context = &self.context[i..];
log_prob += convert_result(self.cts.update(context, ch))?;
self.context.push(ch);
}
self.context = self.context.split_off(text.len());
Ok(log_prob)
}
fn sample(&mut self, n_samples: usize) -> PyResult<String> {
let ctx_len = self.context.len();
for i in 0..n_samples {
let sampled = convert_result(self.cts.sample(&self.context[i..]))?;
self.context.push(sampled);
}
let result = self.context.split_off(ctx_len);
String::from_utf8(result).map_err(Into::into)
}
}
#[pymodule]
fn skipcts(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<SequencePredictor>()
}
(※開発版のAPIを使用しています)
Rustの構造体をPythonにexposeするには#[pyclass]
というアトリビュートをつければいいです。これだけです。
#[new]
がPythonの__new__
にあたるものですが、今は__init__
はオーバーライドできない仕様になっているので、初期化は全部これでやってもらうかたちになります。あとは適当に#[pymethods]
つきのimplブロックの中にメソッドを書いてあげると、引数・返り値にPython型との変換が定義してあれば、それでもう動きます。
エラーハンドリングはちょっと汚いですね。これは色々問題があって中々修正できていないのですが、まあ場当たり的にやっておきましょう。
書けたので、maturinというビルドツールを使って、ビルドしてみましょう。setuptools-rustというのもありますが、難しいことをしないのであればmaturinがオススメです。
仮想環境に入ります。今回はpipenvを使いましたが、別になんでもいいです。
❯ git clone https://github.com/kngwyu/skipcts-rs.git
❯ cd skipcts-rs
❯ pipenv shell
❯ pip install maturin
pyproject.toml
を書きます。
[build-system]
requires = ["maturin"]
build-backend = "maturin"
仮想環境にインストールするにはmaturin develop
コマンドを使います。
❯ maturin develop 3401ms 2019年12月25日 01時46分20秒
🔗 Found pyo3 bindings
🐍 Found CPython 3.8 at python
Compiling pyo3 v0.8.4 (https://github.com/kngwyu/pyo3.git?branch=pyclass-new-layout#58590393)
Compiling skipcts-rs v0.1.0 (/home/mio_h/Programs/skipcts-rs)
Finished dev [unoptimized + debuginfo] target(s) in 3.15s
ビルドできたっぽいですね。使ってみましょう。
❯ pip install ipython
❯ ipython
そこはかとなくDNAっぽい文字列を学習させてみます。
In [1]: from skipcts import SequencePredictor
In [2]: s = SequencePredictor(3, 4, b'a'[0])
In [3]: s
Out[3]: <SequencePredictor at 0x7f6cbce67b70>
In [4]: s.train(b'acgacggcaacg')
Out[4]: -16.700540691964783
In [5]: s.sample(10)
Out[5]: 'gcggcggaac'
それっぽい文字列が出てきました。
せっかくクリスマスなので、クリスマス・キャロルを学習させてみましょう。
In [7]: import requests
In [8]: ch_carol = requests.get("https://www.gutenberg.org/files/46/46-0.txt").text
In [9]: ch_carol[340: 420]
Out[9]: ' Christmas Carol\r\n A Ghost Story of Christmas\r\n\r\nAuthor: Charles Dickens\r\n'
In [10]: s = SequencePredictor(10, 100, b'a'[0])
In [11]: s.train(ch_carol.encode("utf-8"))
Out[11]: nan
In [12]: s.sample(100)
Out[12]: '\r\nAnd within thought\r\nhad the before the stood sat lenty than hung as when the copiece went with fir'
意味不明な文字列が出てきました。ロスもNANになっているし(バグかなあ)、もの悲しいですね。全然クリスマスだという感じはしませんね。
それはそれとして、なんだかすごく簡単にPython拡張が作れる気がしないでしょうか。
クラスの継承は、以下のように書けます。
#[pyclass(extends=SequencePredictor)]
pub struct PredictorExt {
contexts: Vec<usize>,
}
#[pymethods]
impl PredictorExt {
#[new]
fn new(
context_length: usize,
symbol_size: u32,
initial_symbol: u8,
) -> PyResult<pyo3::PyClassInitializer<Self>> {
let mut init = pyo3::PyClassInitializer::from_value(PredictorExt { contexts: vec![] });
let super_ = SequencePredictor::new(context_length, symbol_size, initial_symbol)?;
init.get_super().init(super_);
Ok(init)
}
}
PyClassInitializer
というタイプが面倒なものが出てきましたが、これも中々簡単な気がします。
親クラスを初期化しない場合、実行時エラーになるので気をつけてください(本当はコンパイル時に防ぎたいのですが...)。
他の機能については、ガイドを参照してください。
ただ基本的に、Pythonと接続するレイヤーはなるべく薄く作って、簡単な機能だけ使うのがいいと思います。
そちらの方が設計の見通しがよいし、中身(proc-macro以外のもの)の方は今後かなりAPIが変わるかもしれないので...。
また、基本的にRustオブジェクトとPythonオブジェクトの変換は全てコピーにより実現されることに注意してください。参照だけPython側にexposeする場合は適宜Rc、Boxなどのスマートポインタを使う必要があります。
唯一の例外はrust-numpyのIntoPyArrayによる変換で、これはゼロコピーで配列をPythonに渡せます。パフォーマンスが必要な時には基本的にオススメできる機能です。
問題点の例
というわけで、だいたい使い方を理解していただけたのではないかと思います。
簡単に書けるのは美徳ではありますが、PyO3は主に安全性の面において、かなり問題の多いライブラリでした。今でも、色々言われることがありますし、SIGSEGVを踏みまくってデバッグすることもあります。
もっとも印象に残っているバグは、やはりこれでしょうか。ややこしいバグですけれど、せっかくなので解説してみましょう。
まず、PyO3がPythonのlistやtupleなどのネイティブ型をどのように扱っているかを解説します。
ふむふむ。PyList::new
は&PyList
を返すのか?
ん? &
?
ここで勘のいい読者の方は、「これは何の参照なんだろう]とお思いになったことでしょう。
PyListの定義はラフに言うとこんな感じです。
#[repr(transparent)]
pub struct PyList(*mut pyo3::ffi::PyObject);
そう!PyList自体がすでに、「Pythonが管理しているヒープ上にあるオブジェクトの先頭を指すポインタ」なんですよね2。
じゃあこの参照はどこを指してるんだよ、というと、この中です。
は?とお思いになった方も多いのではないでしょうか?
僕も初めて見た時は、は?と思いました。
実はPyO3はPythonのGlobal Interpreter Lock(GIL)を「取得していること」を示すRAIIタイプであるGILGuardに無理矢理オブジェクトのライフタイムを紐付けるため、一度ポインタを内部のプールに格納して、そこの参照を返します。これは一部の特殊なオブジェクトについては必要なのですが、95%のユースケースについては無駄だと僕は考えています。なんなら将来的にはなくなるかもしれません。
上記のバグは、このポインタがVec
に格納されていたため、reallocateされたときにアドレスが変わって落ちるというとんでもない代物でした。学会の帰りの高速バスで、バス酔いで吐きそうになりながらパッチを書いたのを覚えています。
今はここまでひどいバグはありませんが、まずい設計を引きずっている部分が随所にあります。
正直、僕は未定義動作ゼロとかそういうレベルの安全性を担保するつもりはありません。本当はUnsafeCellが必要だけどFFIをまたぐから大丈夫みたいなUBは、正直絵に描いた餅というか、そんなに解決が必要だと思いません。未初期化についてもRustはすごく厳しいですが、メモリ管理をPythonにやらせる以上、無理な部分がかなりあります。
しかし普通に使っていてSIGSEGVで落ちるというのは、ゼロにしたいなと思います。
これから
PyO3がフォークした経緯について正直あまり知らないのですが、actix-webで有名(?)なfafhrd91さんがフォークしたのが2年くらい前。その後konstinさんがメンテナを引き継いで、僕がメンテナになって1年ちょっとでしょうか。今まではたまにパッチを書くくらいであまり真剣に取り組んでいなかったのですが、あまりに「stableで動くようにしろ」という声が多いので最近は毎週末必ず作業しています。
またkonstinさん曰く
While I'm still occasionally working on maturin and sometimes comment on
an issue or a pull request, I'm not actively maintaining pyo3 anymore,
とのことですので、しばらく僕がメインでやっていくことになります。
正直OSSをやっていて得をしたことはないのですが、まあさすがにもうちょっとクオリティの高いものを作りたいなというか、1.0レベルまでは上げたいので...。
おわりに
最近OSSの作業はRustオンラインもくもく会を利用してやっています。主催のtermoshttさんおよびrust-jpスラック運営の方々に感謝します。
また、PyO3に関してもスラックで質問してくだされば答えるので、よろしくお願いします。
なんか検索していると、「動かない」「枯れていない」「終わり」みたいな記事をたまに見かけます。
もちろん十分な資料を提供していない僕の側も悪いとは思います。
しかし、聞かれれば答えるのに、「わかんなかった」という結論にされてしまうのは、僕は寂しいです。
ディスコミニュケーションを感じます。
わからないことがあればどんどん質問してください。「調べたけどよくわかんなかった」みたいなよくわからん情報をネットに残す時代は、もう終わりにしましょう!
これからは**「質問したらわかりました!わかったことを書きたいと思います!」の時代にしていきましょう!**
おわり。
アペンディクス
Pythonのオブジェクト指向
CPythonの実装においてオブジェクト指向は、タグつきUnion等を用いるのではなく、「先頭一致の規約」により実現されています。
基底クラスは参照カウントと型オブジェクトへのポインタを持ちます。Cはわかりにくいので、RustのFFIで書きます。
#[repr(C)]
pub struct PyObject {
pub ob_refcnt: Py_ssize_t,
pub ob_type: *mut PyTypeObject,
}
これと先頭が一致していると、派生クラスになれます。
#[repr(C)]
pub struct PyListObject {
pub ob_refcnt: Py_ssize_t,
pub ob_type: *mut PyTypeObject,
pub ob_size: Py_ssize_t,
pub ob_item: *mut *mut PyObject,
pub allocated: Py_ssize_t,
}
オブジェクト同士の変換は、単に*mut PyObject
と*mut PyListObject
の変換により行われます。