55
39

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

RustAdvent Calendar 2017

Day 21

primitiv-rustでディープラーニングする

Last updated at Posted at 2017-12-21

本記事はRust Advent Calendar 2017の12月21日の記事です。

RustでディープラーニングしたくてRustでディープラーニングできるようにしたので紹介します。

はじめに

昨今ディープラーニングのフレームワークが勃興し高機能化が進む一方で、Theanoのように一時代を築いたフレームワークの開発停止がアナウンスされるなど1、移り変わりが激しい状況です。
多くのフレームワークはPythonで開発、あるいは公式にPython APIの提供をしており、機械学習・ディープラーニングのライブラリはPythonを中心に利用できるようになっているのですが、諸々の事情でPythonでもなくC++でもなくRustでディープラーニングがしたいというニーズがあるかと思います。

ディープラーニング in Rust

私の知る限りではRust界隈のディープラーニングのフレームワーク事情は以下のとおりです。

Leafはフレームワークのコアの部分までRustで書かれていたのですが、当時はRustでCPU/GPUで高速に行列演算できるライブラリがなく2、ディープラーニングのコアの部分(計算グラフの構築, 微分計算の記述, backward/update)自体よりもバックエンドの行列演算のライブラリ利用・開発に難儀していたようです。3 4
そのような状況でTensorFlowのRustバインディングが公開され、Autumnの開発者がLeafの開発を停止したのですが5、TensorFlowのRustバインディングは学習済みのモデルをloadしてinferenceに使用できるだけというもので、実際にはPythonなど他の言語でモデルの開発・学習・保存する必要があり、Leafの開発の中止は非常に残念でした。6
そこで個人的にDyNetのRustバインディングを試作してある程度動く状況となったのですが7、本格的にバインディングを作るにはまずFFIで関数呼び出しをするためのC APIの開発が必要で8、それにはコアの仕様・実装の理解から始めて膨大な作業時間を要するため、現在は方針を検討中です。9

そのようななか**primitiv**というC++で開発された新興のフレームワークがオープンソースとして公開されました。
個人的な意見ですがprimitivはニューラルネットワークが非常に記述しやすく、フレームワークの特徴が自分のニーズにマッチしており、興味を持ちました。
開発の中心となっている @odashi_t さんと同じ研究コミュニティ・大学院に所属していてコミュニケーションが取りやすい状況であったため、primitivの開発に加わることになりました。
現在は私の担当としてC APIとRustバインディングの開発を進めています。

primitivとは

primitivはもともと NICT (情報通信研究機構) で開発され、オープンソースとして公開されて以降はNAISTの音声・自然言語処理の研究グループを中心に開発が進められています。10

primitivの特徴は、

  • 行列演算のバックエンドの柔軟な切替
  • Define by Run方式の動的な計算グラフの構築

が挙げられ、ざっくりと言うとDyNetのスピンオフのような位置付けです。

これまでChainerやDyNetを使ってきた私の感想として、インターフェースとしては各モジュールの構成はChainerに近く、数式をTheanoのように直感的に記述できるところが優れた点だと思います。
行列演算のバックエンドは、依存ライブラリなしのナイーブなCPUでの演算、Eigenを使ったCPUでの高速な演算、NVIDIAのCUDAを使ったGPUでの演算、OpenCLを使ったAMDやIntelのGPUでの演算が切替可能です。
計算グラフはDefine by Runによって構築されますが、実際のforward計算は明示的に計算結果を得る操作を呼び出すまでは行われません。
この遅延評価方式はフレームワーク側で計算が最適化しやすくなるという利点があります。

インストール方法

primitiv Rustバインディングのインストール方法について説明します。
なお下記手順およびコードは本記事執筆時点のものになります。

# primitiv coreのインストール
$ git clone https://github.com/primitiv/primitiv/
$ cd primitiv
$ git checkout dc022e3fd4c343f7b46f2d04698f940211bca773
$ mkdir build
$ cd build
$ cmake .. -DPRIMITIV_BUILD_C_API=ON
$ make
$ make install

# primitiv-rustのビルド
$ cd ../../
$ git clone https://github.com/primitiv/primitiv-rust/
$ git checkout 0cea052615172e71695d20c80e480fc91b78bc2c
$ cd primitiv-rust
$ cargo build

実装例

MNISTを例に実装方法を簡単に説明します。

モジュールのimport、データの読み込み

今回のサンプルコードで使用するモジュールとデータを読み込みます。

// モジュールのimport
use primitiv::device;
use primitiv::Graph;
use primitiv::Optimizer;
use primitiv::Parameter;
use primitiv::Shape;

use primitiv::devices as D;
use primitiv::functions as F;
use primitiv::initializers as I;
use primitiv::optimizers as O;

// サンプルに使用する定数の定義
const NUM_TRAIN_SAMPLES: u32 = 60000;
const NUM_TEST_SAMPLES: u32 = 10000;
const NUM_INPUT_UNITS: u32 = 28 * 28;
const NUM_HIDDEN_UNITS: u32 = 800;
const NUM_OUTPUT_UNITS: u32 = 10;
const BATCH_SIZE: u32 = 200;
const NUM_TRAIN_BATCHES: u32 = NUM_TRAIN_SAMPLES / BATCH_SIZE;
const NUM_TEST_BATCHES: u32 = NUM_TEST_SAMPLES / BATCH_SIZE;
const MAX_EPOCH: u32 = 100;

...

fn main() {
    // データの読み込み
    let train_inputs = load_images("data/train-images-idx3-ubyte", NUM_TRAIN_SAMPLES);
    let train_labels = load_labels("data/train-labels-idx1-ubyte", NUM_TRAIN_SAMPLES);
    let test_inputs = load_images("data/t10k-images-idx3-ubyte", NUM_TEST_SAMPLES);
    let test_labels = load_labels("data/t10k-labels-idx1-ubyte", NUM_TEST_SAMPLES);

デバイスの登録

primitivでははじめにDeviceのオブジェクトを生成し、device::set_default()によってデフォルトのデバイスとして登録することで、以降の計算をデフォルトデバイス上で行います。

    // デフォルトデバイスの登録
    let mut dev = D::Naive::new();
    device::set_default(&mut dev);

後述しますが、パラメータの初期化や計算グラフの入力ノードの指定に明示的にデバイスを指定することで、デフォルト以外のデバイス上で演算ができ、複数のデバイスを使った処理も可能です。

パラメータの定義・初期化、Optimizerへの登録

パラメータの初期化には明示的にInitializerのオブジェクトを指定することで行います。11

    // パラメータの定義・初期化
    let mut pw1 = Parameter::from_initializer([NUM_HIDDEN_UNITS, NUM_INPUT_UNITS], &I::XavierUniform::new(1.0));
    let mut pb1 = Parameter::from_initializer([NUM_HIDDEN_UNITS], &I::Constant::new(0.0));
    let mut pw2 = Parameter::from_initializer([NUM_OUTPUT_UNITS, NUM_HIDDEN_UNITS], &I::XavierUniform::new(1.0));
    let mut pb2 = Parameter::from_initializer([NUM_OUTPUT_UNITS], &I::Constant::new(0.0));

    // Optimizerへの登録
    let mut optimizer = O::SGD::new(0.5);
    optimizer.add_parameter(&mut pw1);
    optimizer.add_parameter(&mut pb1);
    optimizer.add_parameter(&mut pw2);
    optimizer.add_parameter(&mut pb2);

グラフの構築

primitivではGraphオブジェクトを生成し、計算グラフの構築に使用します。
Graphオブジェクトは計算グラフを構築するための関数を呼び出す際に引数として指定しますが、Graph::set_default()によってデフォルトのグラフとして登録しておくことで、関数呼び出しの際は省略することができます。12

    let mut g = Graph::new();
    Graph::set_default(&mut g);

グラフの構築はfunctions (サンプルコード中ではFというエイリアス)に定義している関数を使って行います。
まず、functions::input()functions::parameter() を用いて入力データ、パラメータからそれぞれNode オブジェクトを生成します。
Nodeは計算結果を示すオブジェクトで、Node同士で直接算術演算を行ったり、functionsモジュールで定義されている関数に引数として渡したりすることで任意の演算を表現することができます。
Nodeに対する演算は指定したGraphに自動的に記録されるため、数式に沿った演算を適用していくだけで計算グラフを構築することができます。

    // 入力のスライスを受け取って`Node`を返すクロージャを定義
    let mut make_graph = |inputs: &[f32], train| {
        let x = F::input(([NUM_INPUT_UNITS], BATCH_SIZE), &inputs);
        let w1 = F::parameter(&mut pw1);
        let b1 = F::parameter(&mut pb1);
        let h = F::relu(F::matmul(w1, x) + b1);
        let h = F::dropout(h, 0.5, train);
        let w2 = F::parameter(&mut pw2);
        let b2 = F::parameter(&mut pb2);
        F::matmul(w2, h) + b2
    };

学習ループ (forward, backward, update) の実行

計算グラフの初期化、グラフの構築、forward演算、backward演算、パラメータの更新は次のようになります。

g.clear();  // 計算グラフの初期化

let y = make_graph(&inputs, true);  // グラフの構築 (デフォルトの`Graph`に`Node`を追加)
let loss = F::softmax_cross_entropy(y, &labels, 0);  // ロスの`Node`を追加
let avg_loss = F::batch::mean(loss);

let loss_val = avg_loss.to_float();  // 計算結果を取得
println!("  loss: {}", loss_val);

optimizer.reset_gradients();  // 勾配のリセット
avg_loss.backward();  // backward演算
optimizer.update();  // パラメータの更新

Define by Run方式のChainerやPyTorchと大きく異なる点は、ChainerやPyTorchが計算グラフ上のノードの追加の際にforwardの計算結果を即時に演算するのに対し、primitivやDyNetではforwardの計算グラフの構築時に値の計算を行わず、具体的な計算結果が必要になったとき (上記の例ではto_float()の呼び出し時) に、初めて値の計算を行うという遅延評価を採用しているところです。

学習ループ全体は下記のようになります。

    let mut rng = thread_rng();
    let mut ids: Vec<usize> = (0usize..NUM_TRAIN_SAMPLES as usize).collect();

    for epoch in 0..MAX_EPOCH {
        rng.shuffle(&mut ids);  // データのインデックスのシャッフル

        for batch in 0..NUM_TRAIN_BATCHES {
            print!("\rTraining... {} / {}", batch + 1, NUM_TRAIN_BATCHES);
            // ミニバッチのデータの取り出し
            let mut inputs: Vec<f32> = Vec::with_capacity((BATCH_SIZE * NUM_INPUT_UNITS) as usize);
            let mut labels: Vec<u32> = vec![0; BATCH_SIZE as usize];
            for i in 0..BATCH_SIZE {
                let id = ids[(i + batch * BATCH_SIZE) as usize];
                let from = id * NUM_INPUT_UNITS as usize;
                let to = (id + 1) * NUM_INPUT_UNITS as usize;
                inputs.extend_from_slice(&train_inputs[from..to]);
                labels[i as usize] = train_labels[id] as u32;
            }

            // 計算グラフの初期化
            g.clear();

            // 計算グラフの構築
            let y = make_graph(&inputs, true);
            let loss = F::softmax_cross_entropy(y, &labels, 0);
            let avg_loss = F::batch::mean(loss);

            // 勾配の初期化, backward, パラメータの更新 ※トレーニング時のみ
            optimizer.reset_gradients();
            avg_loss.backward();
            optimizer.update();
        }

        println!();

        let mut match_ = 0;

        for batch in 0..NUM_TEST_BATCHES {
            print!("\rTesting... {} / {}", batch + 1, NUM_TEST_BATCHES);
            // 評価データの取り出し
            let mut inputs: Vec<f32> = Vec::with_capacity((BATCH_SIZE * NUM_INPUT_UNITS) as usize);
            let from = (batch * BATCH_SIZE * NUM_INPUT_UNITS) as usize;
            let to = ((batch + 1) * BATCH_SIZE * NUM_INPUT_UNITS) as usize;
            inputs.extend_from_slice(&test_inputs[from..to]);

            // 計算グラフの初期化
            g.clear();

            // 計算グラフの構築
            let y = make_graph(&inputs, false);

            // 計算結果の算出
            let y_val = y.to_vector();
            // 評価
            for i in 0..BATCH_SIZE {
                let mut maxval = -1e10;
                let mut argmax: i32 = -1;
                for j in 0..NUM_OUTPUT_UNITS {
                    let v = y_val[(j + i * NUM_OUTPUT_UNITS) as usize];
                    if v > maxval {
                        maxval = v;
                        argmax = j as i32;
                    }
                }
                if argmax == test_labels[(i + batch * BATCH_SIZE) as usize] as i32 {
                    match_ += 1;
                }
            }
        }

        let accuracy = 100.0 * match_ as f32 / NUM_TEST_SAMPLES as f32;
        println!("\nepoch {}: accuracy: {:.2}%", epoch, accuracy);
    }

ソースコード全体は下記に置いています。

primitiv-rust/mnist.rs - GitHub

GPU上での実行

CUDAやOpenCLを使ったGPU上での演算は、使用するデバイスを切り替えるだけで可能になります。

    let mut dev = D::CUDA::new(0);
    // let mut dev = D::Naive::new();
    // let mut dev = D::OpenCL::new(0, 1);
    // let mut dev = D::Eigen::new();
    device::set_default(&mut dev);

複数デバイスを使用した計算グラフの構築・演算

Parameterの生成や入力に使用するNodeの構築時に明示的にデバイスを指定することで、複数のデバイスを使用することができます。

let mut dev0 = D::CUDA::new(0);
let mut dev1 = D::CUDA::new(1);

let mut pw1 = Parameter::from_initializer_with_device(
    [NUM_HIDDEN_UNITS, NUM_INPUT_UNITS],
    &I::XavierUniform::new(1.0),
    Some(&mut dev0),
);
let mut pw2 = Parameter::from_initializer_with_device(
    [NUM_HIDDEN_UNITS, NUM_INPUT_UNITS],
    &I::XavierUniform::new(1.0),
    Some(&mut dev1),
);

...

let x1 = F::input_with_device(([NUM_INPUT_UNITS], BATCH_SIZE), &inputs1, Some(&mut dev0));
let x2 = F::input_with_device(([NUM_INPUT_UNITS], BATCH_SIZE), &inputs2, Some(&mut dev1));

Future Work

primitivのコアの機能の開発としては下記を予定しています。

  • マルチスレッドでの演算の対応
  • Define by Run + 遅延評価 を利用したミニバッチ計算の最適化 (autobatch)
  • 高階微分
  • Convolution, Pooling演算の実装 13

RustバインディングはC APIがunstableなため開発が遅れていますが、これから順次対応を進めていきます。
(インストール方法やドキュメントの整備、テストスクリプトの作成等いろいろ間に合わず、大変申し訳ないですが興味を持った人がすぐに試せる状況に至っていません。)

他の言語のバインディングとしてはPythonでprimitiv version 0.3の機能が使えるようになっています。
これからJavaのバインディングを開発し、KotlinやScalaからも使えるように進めていく予定です。

まとめ

本記事ではディープラーニングのフレームワーク primitiv のRustバインディングの概要と、Rustバインディングを使ったモデルの構築・学習方法について紹介しました。
primitivは異なるプラットフォームのデバイスを使い分けできる遅延評価型のDefine by Run方式の計算グラフ構築・演算を特徴とする新興のフレームワークです。
後発のフレームワークですが既存のフレームワークの良いところを取り入れ、正式なリリースに向けて機能の拡充をしています。

primitiv developerチームは開発に協力してくれる人を募集しています。
興味のある方は @odashi_t さん、または @chantera までご連絡ください。

追記

2017年12月25日

インストール方法について記載しました。
https://qiita.com/odashi_t/items/4b90efc20b09c30e58bf

Leafの後継 Juice について記載しました。
@fx-kirin さん、情報提供ありがとうございます。

@odashi_t さんがprimitivの解説記事を書いてくれました。
primitiv: 新しいニューラルネットワークのライブラリ


  1. MILA and the future of Theano

  2. 現在ではCPUでの行列演算のライブラリとしてrust-ndarrayなどが利用可能。

  3. hobofan 771 days ago, 記事: Leaf: machine intelligence framework in Rust | Hacker News

  4. Collenchyma: CUDA, OpenCL and Native Machine Intelligence for Leaf : rust

  5. Tensorflow wins – Michael Hirn – Medium

  6. Leaf の後継もあるらしいです。 spearow/juice - GitHub

  7. dynet-rs/bilstmtagger.rs - GitHub

  8. rust-bindgenのC++対応が完全ではないため・

  9. https://github.com/chantera/dynet-rs/issues/1

  10. NAISTは奈良の山奥に位置していて娯楽がないので、ディープラーニングのフレームワークを作るぐらいしかやることがないという人が割といます。

  11. ディープラーニングではパラメータの初期化は学習結果に大きく影響を及ぼす要因ですが、フレームワーク側でデフォルトの初期化方法を決めてしまうとパラメータの初期化が開発者に軽視されるだけでなく、フレームワークのソースコードを読んで仕様を把握する必要がある(最悪の場合はデフォルトの初期化方法が変更されてしまう)ため、個人的には現状のprimitivのようにInitializerを指定するインターフェースが良いと思います。

  12. 関数呼び出しの際に明示的にグラフを渡すような状況として、複数のグラフを使ったマルチスレッドでの計算グラフの構築・演算を想定しています(マルチスレッドでの計算グラフの構築・演算の機能は現在開発中です)。

  13. 開発陣が自然言語処理関係の専門の人が多数のため、開発の優先度が低いです。画像関係でディープラーニングに詳しい開発者・研究者からのご意見・ご協力をお願いします。

55
39
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
55
39

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?