LoginSignup
3
4

More than 5 years have passed since last update.

tiny-dnnを用いたIrisデータセットの学習

Posted at

C++の深層学習用フレームワークであるtiny-dnnを用いて、Pythonのsklearnで提供されているirisのデータセットを学習させました。

まとめ

  • tiny-dnnを用いてsklearnで提供されているirisのデータセットを学習させるコードを実装。
    • コードはGitHubで公開。
  • tiny-dnnはKerasのような感覚で使いやすい。

環境

  • Windows 10 Pro (64 bit) Version 1607
  • Intel Core i5-3317U (1.70 GHz×2), RAM 4 GB
  • Visual Studio 2017 Community
  • tiny-dnnとsklearnを利用

背景

これまでPythonをよく利用してきたこともあり、Tensowflow, sklearn, Kerasあたりを使ってきました。
とある事情によりc++で、しかもx86環境で深層学習を扱う必要が生じました。
c++での深層学習フレームワークを探したところ、Microsoft社のCNTK, Cで書かれたdarknet, そしてtiny-dnnあたりがよさそうだということが分かりました。
それぞれについての検討状況は次の通りであり、今回はtiny-dnnを利用することにしました。

  • CNTK:2017/06/18 (Sun.) 現在、Visual Studio 2015 (c++14)のx64環境でしか動かないようである。
    • VS2017でもv140のツールセットをインストールすれば動く?いずれにせよx86では動かない模様。
  • tiny-dnn:x86でも動作する。
  • darknet:tiny-dnnがx86環境で動作することが判明したため未検討。

tiny-dnnについて

tiny-dnnはC++のx86/x64環境で利用できる深層学習フレームワークの一つです。
詳細や導入方法はOriginal authorの方に紹介されていますのでここでは割愛させていただきます。
ほかには実装ノート(バージョンが古いようですが、分かりやすいです。)やDocumentationが参考になります。
Qiitaでは『FPGAでDeep Learningしてみる』という報告があり、こちらではBNN-PYNQというtiny-dnnを利用したプロジェクトが利用されています。

コードの実装

コードはGitHubで公開しています。
2017/06/18 (Sun.) 時点での実装ですが、関数main.cpp\main(void)内で関数dnn_iris(size_t batch_size, size_t epoch)が実行されています。
この関数の内容は次の通りです。

  1. 関数load_iris_vec_tでirisのデータセット(プロジェクト内では.\\data\\iris.csv)を読み込み、特徴量Xとラベルyに分ける。
  2. 関数split_train_dataで、読み込んだirisのデータセットを学習データとテストデータに分ける。
  3. モデルを定義する。今回は入力層、中間層、出力層の三層パーセプトロンとした。

    network<sequential> model;
    size_t n_neural = 100;
    model << fc(X[0].size(), n_neural) << activation::tanh()
          << fc(n_neural, y[0].size()) << activation::softmax();
    
  4. モデルを次のように訓練する。

    adagrad opt;
    model.train<cross_entropy_multiclass>(opt, X_train, y_train, 
    batch_size, epoch);
    
  5. 訓練されたモデルを用いてテストデータを分類する。そのときのlossとaccuracyを計算する。

テストデータとして読み込ませているiris.csvには、sklearn.datasets.load_irisで得られるデータセットのXyが次の形式で保存されています。

// X_k (k=1,2,3,4): 特徴量
// y: ラベル
X_1, X_2, X_3, X_4, y
5.1, 3.5, 1.4, 0.2, 0
4.9,   3, 1.4, 0.2, 0
4.7, 3.2, 1.3, 0.2, 0
4.6, 3.1, 1.5, 0.2, 0
  5, 3.6, 1.4, 0.2, 0
5.4, 3.9, 1.7, 0.4, 0
...

エポック数を変えて学習させた結果

試しにバッチ数を10に固定し、エポック数を10~100で変化させたときのloss、accuracy、訓練時間を調べてみました。その結果を次に示します。

Batch_size, Epoch, Elapsed time of training, Loss, Accuracy
10,  10,  3.95 sec, 92.2440, 0.6944
10,  20,  7.03 sec, 65.5258, 0.6944
10,  30, 10.86 sec, 62.9554, 0.7500
10,  40, 14.84 sec, 51.8333, 0.8889
10,  50, 18.56 sec, 52.2328, 0.8889
10,  60, 22.28 sec, 47.6374, 0.9167
10,  70, 27.04 sec, 46.4205, 0.9167
10,  80, 30.26 sec, 43.7856, 0.9167
10,  90, 35.74 sec, 36.4594, 0.9167
10, 100, 38.84 sec, 34.6639, 0.9444

今後やってみたいこと

  • CSVファイルを読み込む関数を用意しているので、天気予報のデータや株価などほかのデータも学習させてみる。
  • modelの保存・読み込みを行う関数が提供されているのでそれを使ってみる。
  • x86のDLLにして外部から利用できるようにしてみる。
  • などなど

参考サイト

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4