LoginSignup
20
18

More than 1 year has passed since last update.

Rustでディープラーニングライブラリをフルスクラッチで実装した

Posted at

この記事では私が作ったRust製ディープラーニングライブラリminiatureを紹介します.

はじめに

miniatureは私がRustの勉強がてらに作ってみたものなので,実用的に何かに使えるものではありません.しかし,非常にシンプルな設計なので,TensorFlowやPyTorchなどのディープラーニングライブラリが裏で何をやっているかを勉強したい人にとってはいい参考になるのではないかと思います.また,拡張も容易なので大学の授業で使うのにも向いているかもしれません(論文読んで実際にレイヤーや最適化手法を実装するなど).しかも,Pythonで同じことをするとnumpyを使ってサクッと出来てしまいますが,Rustではゴリゴリ行列計算を書かせることができるので非常に教育的です.

設計がシンプルと言いましたが,ソニー製のディープラーニングライブラリであるnnablaに頻繁にプルリクを送っていたこともあり,設計する際に大きな影響を受けています.nnablaはminiatureに比べるととても大きなソフトウェアですが,かなり綺麗に設計されているので,興味がある人は是非覗いてみてください.

また,筆者がRust素人のため,Rustに関する記述が誤っている場合があります.その時は筆者も大変勉強になるのでご指摘ください.

特徴

miniatureの特徴は大きく分けて3つあります.

  • ChainerやPyTorchのようにdefine-by-runでグラフを作る
  • 最近のディープラーニングライブラリと同じ感覚で書ける
  • 外部のライブラリを利用していないのでスタンドアローンで利用できる
use miniature::functions as F;
use miniature::graph::backward;
use miniature::optimizers as S;
use miniature::parametric_functions as PF;
use miniature::variable::Variable;

use std::rc::Rc;
use std::cell::RefCell;

fn main() {
    // define layers
    let fc1 = PF::linear(28 * 28, 256);
    let fc2 = PF::linear(256, 256);
    let fc3 = PF::linear(256, 10);

    // define optimizer
    let mut optim = S::adam(0.001, (0.9, 0.999), 1e-8);
    optim.set_params(fc1.get_params());
    optim.set_params(fc2.get_params());
    optim.set_params(fc3.get_params());

    let x = Rc::new(RefCell::new(Variable::rand(vec![32, 28 * 28])));
    let t = Rc::new(RefCell::new(Variable::rand(vec![32])));

    // forward
    let h1 = F::relu(fc1.call(x));
    let h2 = F::relu(fc2.call(h1));
    let y = fc3.call(h2);

    // loss
    let loss = F::cross_entropy_loss(y, F::onehot(t, 10));

    // update
    optim.zero_grad();
    backward(loss);
    optim.update();
}

主要な登場人物を説明します.

miniature::variable::Variable

VariableはPyTorchでいうtf.Tensor,TensorFlowでいうtf.Variableに相当します.

Variableはテンソルや勾配を格納しています.さらに,バックプロパゲーションで出力から入力を辿るためにparentにどのレイヤーから出力されたものかを記録します.

pub struct Variable {
    pub parent: Option<Rc<RefCell<CgFunction>>>,
    pub shape: Vec<u32>,
    pub data: Vec<f32>,
    pub grad: Vec<f32>,
    pub need_grad: bool,
}

ここで注目して欲しいのが,datagradが一次元配列であるところです.
多次元配列も本質的にメモリの上では一次元配列と同じであるため,テンソルは一次元で管理して,多次元配列としての計算が必要な時にshapeの情報を利用します.

miniatureでは基本的にVariableがレイヤーやオプティマイザなど複数箇所で参照されるため,常にRc<RefCell<Variable>>として利用します.

miniature::functions

miniature::functionsは上記の例のようにさまざまなテンソル操作を提供しています.

例えばreluの順伝播と逆伝播の計算は以下のような実装になっています.

pub struct ReLu {}

impl FunctionImpl for ReLu {
    fn forward_impl(
        &mut self,
        inputs: &[Rc<RefCell<Variable>>],
        outputs: &[Rc<RefCell<Variable>>],
    ) {
        let x = inputs[0].borrow();
        let mut output = outputs[0].borrow_mut();

        for i in 0..x.size() as usize {
            output.data[i] = if x.data[i] > 0.0 { x.data[i] } else { 0.0 };
        }
    }

    fn backward_impl(
        &mut self,
        inputs: &[Rc<RefCell<Variable>>],
        outputs: &[Rc<RefCell<Variable>>],
    ) {
        let mut x = inputs[0].borrow_mut();
        let output = outputs[0].borrow();

        for i in 0..x.size() as usize {
            x.grad[i] += if x.data[i] > 0.0 { output.grad[i] } else { 0.0 };
        }
    }
}

ここでFunctionImplはインターフェースを揃えるためのtraitですが,Pythonライブラリのような感覚で利用できるようするためにユーザがF::reluで呼び出しているのは以下のようなラッパー関数になります.

pub fn relu(x: Rc<RefCell<Variable>>) -> Rc<RefCell<Variable>> {
    let output = Rc::new(RefCell::new(Variable::new(x.borrow().shape.clone())));
    let function = Box::new(ReLu {});
    .
    .
    .
    output
}

これはユーザ側でtraitの関数を呼び出す際に,useを使ってtraitを読み込まなければならないのを避けるための設計です.

miniature::parametric_functions

miniature::parametric_functionsでは重みやバイアスのような学習パラメータを含むレイヤーを提供しています(とは言っても,執筆時点では全結合層のみが実装されています).

miniature::functionsの方でテンソル操作が実装されているので,基本的にこのレイヤーではそれらを呼び出すだけになります.

pub struct Linear {
    weight: Rc<RefCell<Variable>>,
    bias: Rc<RefCell<Variable>>,
    out_size: u32,
}

impl Linear {
    .
    .
    .
    pub fn call(&self, x: Rc<RefCell<Variable>>) -> Rc<RefCell<Variable>> {
        let batch_size = x.borrow().shape[0];
        let h = F::matmul(x, self.weight.clone());
        let broadcasted_bias = F::broadcast(self.bias.clone(), vec![batch_size, self.out_size]);
        F::add(h, broadcasted_bias)
    }
    .
    .
    .
}

miniature::optimizers

miniature::optimizersはSGDやAdamのような最適化手法を提供します(執筆時点ではSGDとAdamのみ実装しています).

参考までに,SGDはこのような実装になっています.

pub struct Sgd {
    pub lr: f32,
}

impl OptimizerImpl for Sgd {
    fn update(&mut self, params: &[Rc<RefCell<Variable>>]) {
        for param in params {
            let mut param = param.borrow_mut();
            for j in 0..param.size() as usize {
                param.data[j] -= self.lr * param.grad[j];
            }
        }
    }
}

ユーザ側のコードでOptimizerImplのtraitをuseで持ってこないように,miniature::functionsと同様なラッパー関数経由で呼び出します.

MNIST学習

miniatureを使って実際にMNISTを学習することが出来ます.

READMEにしたがって以下のコードをとりあえず動かすことが可能です.だいたい200イテレーションほどでテストデータセットに対して80%以上の精度が出ると思います.

use miniature::datasets::MNISTLoader;
use miniature::functions as F;
use miniature::graph::backward;
use miniature::optimizers as S;
use miniature::parametric_functions as PF;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let dataset = MNISTLoader::new("datasets")?;
    let (test_x, test_t) = dataset.get_test_data();

    let fc1 = PF::linear(28 * 28, 256);
    let fc2 = PF::linear(256, 10);

    let mut optim = S::adam(0.001, (0.9, 0.999), 1e-8);
    optim.set_params(fc1.get_params());
    optim.set_params(fc2.get_params());

    let mut iter = 0;
    loop {
        let (x, t) = dataset.sample(32);
        let onehot_t = F::onehot(t, 10);

        // forward
        let h = F::relu(fc1.call(x));
        let output = fc2.call(h);

        // loss
        let loss = F::cross_entropy_loss(output, onehot_t);

        optim.zero_grad();
        backward(loss);
        optim.update();

        iter += 1;
        if iter % 100 == 0 {
            // test
            let h = F::relu(fc1.call(test_x.clone()));
            let output = F::argmax(fc2.call(h));

            let mut count = 0;
            let test_size = output.borrow().shape[0];
            for i in 0..test_size as usize {
                let pred_label = output.borrow().data[i] as u8;
                let test_label = test_t.borrow().data[i] as u8;
                if pred_label == test_label {
                    count += 1;
                }
            }
            let accuracy = (count as f32) / (test_size as f32);
            println!("Iteration {}: Accuracy={}", iter, accuracy);
        }

        if iter == 100000 {
            break;
        }
    }

    Ok(())
}

ディープラーニングライブラリとして足りないところ

以上の説明で結構作り込まれている印象を受ける方がいるかもしれませんが,miniatureは以下の点でかなり実装をサボっています.

  • いくつかのfunctionは特定のケースでしか動かないように実装されている
    • 例1: F::broadcastは最初の次元(主にバッチ数)のみブロードキャスト可能
    • 例2: F::argmaxは2次元マトリクスの列のargmaxしか計算しない
  • backwardのテストを一切書いていない
    • 本来は数値微分の結果を真値にして自動微分のテストを行います

また,PyTorchやTensorFlowのような実用的なものに対しては,当然ですが機能の豊富さ以外で実用的であるために必要なものが足りません.

CUDAのサポート

現在はCPUかつシングルスレッドでしか処理していないため,ImageNetデータセットの学習や自然言語処理などは途方もない時間がかかります.少し調べてみるとCUDAのカーネルコードをRustから使うこともできるようなので,CUDAのサポートはそこまで難しくなさそうです.

メモリの効率的な管理

CUDAのサポートと同じくらい重要な要素として,メモリリソースの効率的な管理を行う必要があります.ディープラーニングでは大きな行列を頻繁に作る必要があるので,いちいちヒープから取ってくると非常に効率が悪いです.

そのため,PyTorchやTensorFlow, nnablaでは裏側で事前に確保しておいたメモリを割り当てます.メモリを開放するときは,実際にreleaseは行わずに裏で取っておいて必要な時に再び割り当てます.GPUで計算を行う場合はメモリのallocationが非常に時間がかかるので特に重要な仕組みになります.

Rust初心者なのでわからないところ

Rc<RefCell<T>> 使いすぎ?

Rustのまだ慣れていないため,C++やPythonで実装するのと同じ感覚で全体を設計してしまっています.そのためRc<RefCell<T>>を随所で使っています.ただ,これを多用しすぎるとRustの魅力の一つであるコンパイル時のチェックの恩恵が受けられません.実際にコンパイルは通るが実行時にborrow_mut()でエラーが出ることもありました.

本当はRustの言語仕様に沿った,Rustらしいディープラーニングライブラリの設計というものがある気がします.

traitを隠すためにラッパーを作るのは一般的?

今回はユーザ側が最小限のコードで利用できるようにするために,traitのためだけにuseを使わなくてもいいようにしてあります.ただ,他のRustコードの知見が非常に少ないので,これが一般的にやられていることなのか,それともより良い方法があるのか気になります.

宣伝

ここまで読んでいただきありがとうございました!

miniatureにもし興味があったら是非コードも見てみてください.よりRustっぽい書き方があればプルリクなどで送っていただけると大変嬉しいです.

宣伝ですが,2020年度未踏IT人材発掘・育成事業で開発したオフライン強化学習ライブラリをNeurIPS 2021 Offline RL Workshopで発表できることになりました.もしご興味があればこちらも是非ご覧ください.
- リポジトリ: https://github.com/takuseno/d3rlpy
- 論文: https://arxiv.org/abs/2111.03788

20
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
20
18