Rust Advent Calendar 2017 6日目の記事です。
rust-autograd というライブラリを作りました。https://github.com/raskr/rust-autograd
もともとニューラルネットのライブラリのつもりだったんですが、GPUの知識がなさすぎて無理だと気づき、一般的な名前にしました。一応ドキュメント(兼テスト)もあります。例えば行列積の計算はこんな感じです:
extern crate autograd as ag;
let a: ag::Tensor = ag::zeros(&[4, 2]);
let b: ag::Tensor = ag::zeros(&[2, 3]);
let c: ag::Tensor = ag::matmul(&a, &b);
println!("{}", c.eval(&[]).unwrap());
// [[0, 0, 0],
// [0, 0, 0],
// [0, 0, 0],
// [0, 0, 0]]
とりあえず以下のようことができます:
- 多次元配列の計算. バックエンドは rust-ndarray
- 勾配の計算. N階微分に対応してます
- 勾配降下法でのパラメータ最適化
100% Rust なので計算グラフを実行時にコンパイルしたりはしないため, 手軽に動かせます。逆にできないことは:
- GPUで実行
Convolution系
つまり非常に残念なことにディープラーニングは絶望的ですが、簡単な手書き文字認識くらいならできます。RNN も一応動きます。以下は簡単な勾配計算の例です。API の実装は先人たちを大いに参考にさせてもらいました。(ほんとうに申し訳ありません)
extern crate ndarray;
extern crate autograd as ag;
// z = 2x^2 + 3y + 1 の偏微分を計算する
let ref x = ag::placeholder(&[]);
let ref y = ag::placeholder(&[]);
let ref z = 2*x*x + 3*y + 1;
// dz/dy
let ref gy = ag::grad(&[z], &[y])[0];
// dz/dx
let ref gx = ag::grad(&[z], &[x])[0];
// ddz/dx (differentiates `z` with `x` again)
let ref ggx = ag::grad(&[gx], &[x])[0];
// evaluation of symbolic gradients
println!("{:?}", gy.eval(&[])); // => Ok(3.)
println!("{:?}", ggx.eval(&[])); // => Ok(4.)
// dz/dx requires to fill the placeholder `x`
println!("{:?}", gx.eval(&[(x, ndarray::arr0(2.))])); // => Ok(8.)
上記はただのスカラー値の四則演算の例ですが、"ここ"に定義してある命令はすべて(微分可能なら)微分できます。 TF 同様計算グラフベースなのでN階微分に対応してます。 TF にあるが rust-ndarray にない関数もいくつか実装してます (e.g. transpose, gather, tensordot, argmax, ...)
rust-ndarray
実際の数値計算は rust-ndarray というクレートに依存しています。これは numpy のようなものかと期待しましたが、少し違います。というか
Rust よくわからない状態で始めたのもあってかなり面食らいました。特徴としては:
- Pure Rust なので結構速い(と思う)
- map, fold などのコレクション系の高階関数が提供されていて、抽象的な API になっている (Rustのイテレータ+高階関数はちゃんと使えば最適化されるので速いっぽい)
- 他にも Zip や AxisIter, Read-only View のようなRustっぽい面白い API がある
- ndarray のランクは型引数として受け取り、静的に扱う (C++ の Eigen みたいな)
- numpy や TF に慣れている人が当たり前に使う関数がないことが結構ある.
- 基本的な二項演算やブロードキャストなどは違和感なく使える
などなどです。 numpy なんかに比べたらAPIの充実具合は流石にアレですが, 基本機能は揃ってると思います。 rust-ndarray 以外には rayon という OpenMP みたいなライブラリを使ったデータ並列も一部してますが外部依存はほとんどないです (というか依存は Cargo が勝手に解決してくれるのでどうでもいい)
autograd::op::Op を実装する
その rust-ndarray と autograd を使って勾配計算に対応したオペレーションを実装する例は以下です:
extern crate ndarray;
extern crate autograd as ag;
// めんどくさいので今のところ f32 に限定してます。
type NdArray = ndarray::Array<f32, ndarray::IxDyn>
// シグモイド関数を例に
struct Sigmoid;
// Op トレイトを実装する
impl ag::op::Op for Sigmoid {
fn name(&self) -> &str
{
"Sigmoid"
}
// 実際に Sigmoid 関数の出力を計算するメソッド。
fn compute(&self, ctx: ag::runtime::OpComputeContext)
-> Vec<Result<NdArray, ag::op::ComputeError>>
{
let xs = ctx.grab_inputs();
let x = xs[0];
// ndarray::Array::mapv で x の全要素の sigmoid を計算する。
// 今はしてませんが rayon を使えば簡単に並列化させることもできます。
let y = x.mapv(|a| ((a * 0.5).tanh() * 0.5) + 0.5);
vec![Ok(y)]
}
// "gxs=最終出力に対するxsの勾配" を定義するメソッド。
// gy=yの勾配, xs=入力, y=出力 がそれぞれ `Tensor` として渡ってくるので、
// それらと微分の連鎖律をうまく使って gxs を定義します。 xs が勾配を必要としてないなら
// None を返せばいいです。
fn grad(&self, gy: &ag::Tensor, xs: &[&ag::Tensor], y: &ag::Tensor)
-> Vec<Option<ag::Tensor>>
{
// sigmoid の導関数は sigmoid-sigmoid^2 つまり y-y^2 です。
// 微分の連鎖律を途切れさせたくないのでそれを gy にかけて return します。
let gx = gy * (y - ag::square(y));
vec![Some(gx)]
}
}
他のNNライブラリの実装もだいたいこんな感じになってると思いますが、とても簡単に新しいオペレーションを定義できました。あとはこれを使うためのヘルパーを書いて終了です:
fn sigmoid(x: &ag::Tensor) -> ag::Tensor
{
ag::Tensor::builder()
.set_inputs(vec![x])
.set_shape(x.shape())
.build(Sigmoid)
}
let a = ag::zeros(&[3]);
let y = sigmoid(&a);
y.eval(...
繰り返しになりますが、メジャーな関数は実装してあります。(頑張りました...)
実装上つらかったところ
分かってはいましたが、計算グラフ上のノードは基本的に共有されるためムーブセマンティクスと相性が悪く、とはいえ新しいノードが内部でポコポコ生成されるので参照のライフタイムの管理も難しく、結局参照カウントにメモリ管理を任せてます。 逆に言えば参照カウントを適当に使ってれば何も神経質にならずに書けるのですが、Rc::clone
が予想以上に現れたので微妙な気持ちになりました。まあこれはしょうがない。
また、グラフのノード/計算結果を一元管理してる都合上, ノードの出力である ndarray のランク(Generic type) もndarray::IxDyn
という動的なものに統一しないといけないというのもあります。これも今のところどうしようもないです。
まとめ
- Rust初心者が rust-ndarray 上で動く自動微分を作ってみた
- MNIST がそこそこな速度で動くことは確認した
- rust-ndarray 結構使えるかも
- Rust でグラフ構造つらい
自分でも使ってないのでどれくらい使えるのか知りませんが、少なくとも rust-ndarray の方は良く出来てると思います。 他にも cuDNNのライブラリや cudaカーネルを書けるライブラリもあるみたいだし、意外と Rust で機械学習できそうな気はします。autograd では計算グラフの構築と評価、勾配計算は基本的に ndarray に依存しない作りにしているので頑張ればGPUでも動かせそうなのですが、今のところ知識とモチベーションが足りてないです... 以上です。