はじめに
Rustで自動微分するものを作ってみました。よければスターとかissueとかをしてくれるととても喜びます。(友達に言ってみたけど見事にスルーされた。悲しい)
また、この実装は気持ち悪い。この書き方はRustらしくないなどご指摘もしていただきたいです。よろしくお願いします。(でもマサカリは投げないで)
また、パフォーマンスを求めるのであればしっかりとしたクレートも存在しています。
使い方
私が作成したクレートの簡単な使い方は以下のコードの様になります。
現在実装してある、計算グラフの演算方法は、行列積(matmul
)、加算(add
)、要素積or積(product
)です。これらの方法は、サンプルコードのように実装を行えば計算グラフを構築することができます。
use autograd::function::*;
use autograd::node::Node;
use autograd::tensor::Tensor;
fn main() {
// 勾配情報を持つtensorの定義(この場合はスカラー)
let x = Tensor::new(&[]);
// y = x * x + xの計算グラフを定義する
let y = add(&product(&x, &x), &x);
// 計算グラフの購買情報をリセット
y.zero_grad();
// 購買情報を伝搬する前にグラフの最も下のノードの勾配をセットする
y.set_grad();
// 計算グラフの計算を行う
y.forward();
// 計算グラフの勾配を計算する
y. backward();
}
基本的な構造
※この成果物的なものはなるべく他のクレートを見ないで自分のコーディング能力だけで実装したものです。そのため、一部をのぞき全て自分で実装しています。クソコードの可能性があります。
行列計算ライブラリ
行列計算ライブラリはRust用のものも存在します。私はndarray
を利用しました。
勾配情報を持つ構造体
自動微分では、勾配情報を持つためそれ専用の構造体を定義しました。
#[derive(Debug)]
pub struct Tensor<'a, T: Float> {
pub input: UnsafeCell<NdArray<T>>,
pub grad: UnsafeCell<NdArray<T>>,
pub dim: &'a [usize]
}
dim
はtensorの次元を表し, input
はこのテンソルが持つ計算するための値orテンソルを持ちgrad
は勾配を持ちます。また、構造体自体はイミュータブルでもinput
とgrad
は更新する必要があります。そこでUnsafeCell
を用いました。
この構造体は、自動微分をする際の全てで用いられる構造体です。
Nodeトレイト
計算グラフのノードとしてのトレイトを実装します。
pub trait Node<'a, F: Float> {
fn get_input(&self) -> *mut NdArray<F>;
fn get_grad(&self) -> *mut NdArray<F>;
fn reset_grad(&self);
fn shape(&self) -> &'a [usize];
fn forward(&self) ;
fn backward(&self);
fn zero_grad(&self);
fn set_grad(&self);
}
このトレイトは全ての計算グラフの操作を実装したトレイトになっています。
-
get_input
: コードの通り、Nodeにおける値の生ポインタを返します。 -
get_grad
: コードの通り、Nodeにおける勾配の生ポインタを返します。 -
reset_grad
: ノードの勾配をゼロにします。 -
shape
: ノードのテンソルの次元を返します。 -
forward
: ノードにおける演算or計算グラフ自体の計算を行います。 -
backward
: ノードのor 計算グラフ全体の勾配を計算します。 -
zero_grad
: グラフの勾配をゼロにします。 -
set_grad
: グラフの一番最後のノードの勾配を1に設定します。
計算グラフの構築方法
グラフの演算処理(add
, matmul
, product
)はノードとノード(Node
トレイトが実装されているもの)をつなげた計算グラフにおけるノードを出力し、ノード同士を繋ぎながら計算グラフを構築していきます。
let x = Tensor::new(&[19, 19]);
let y = Tensor::new(&[19,19]);
// xとyの加算を行う計算グラフ
let add_x_y = add(&x, &y);
let z = Tensor::new(&[19,19]);
// add_x_yとzの積算をする計算グラフ
let product_z_add_x_y = product(&z, &add_x_y);
上のコードでは、計算グラフは
x
\
+ (add_x_y)
/ \
y \
×-----> product_z_add_x_y
/
/
/
/
/
z
が構築されます。
計算グラフの計算、勾配計算方法
グラフの演算処理が返り値として得られたノードはノードとノードを繋いだ計算グラフそのものと見なすことができます。その計算グラフを実行するにはサンプルコードに書いてあるとおり、forward()
を行えばよいです。また、一番最後のノードの勾配を1に設定する、set_grad
を行い、backward()
を行えば全てのノードの勾配を計算することができます。
今後の課題
- GPUに対応させる
- functionに入力されるテンソルの次元チェッカーを実装する(forwardする前にエラーを吐かせたい)
- ニューラルネットでよくあるものを実装したい