LoginSignup
1
0

More than 3 years have passed since last update.

TFS_DNNでCIFAR10を作ってみた

Last updated at Posted at 2020-03-27

組み込み用のDNNが欲しくて性能ではなく、とにかくC++辺りでポータブルに出来ないかと思っていろいろ探してみた。iPhoneでも使えるDeepBeliefSDKとかもあるが、これでも私にとってはちょっと規模がでかいと感じる。そこで、いろいろ探しまわったら丁度良いのが見つかった。とにかくコード的に軽量で手軽なものを探した。結果としては、tfs_dnnというのを見つけた。(多分、2016/11現在、日本人でこれを触ったのは私が初めてだろう)
https://github.com/barrettd/tfs_dnn

こいつを選んだもうひとつの理由は、convnetjsにインスパイヤーされて作られているということだ。
http://cs.stanford.edu/people/karpathy/convnetjs/
JavaScriptベースのDNNで、無論速度の期待は一切出来ない。しかし、可視性が良く学習過程が非常に良くわかる。DNNの学習には持ってこいという感じがする。tfs_dnnは、このコードを多分に用いているため、convnetjsである程度試した結果を、C++で実現する流れができる。これらの作業の中に、CUDAなんつう高級なものは一切存在しない。他の高機能なDNNには当然勝てないが、なんせ見通しがいい。この後の作業としては、FPGAに高位合成でDNNチップを作ろうという流れである。(なるべくベタなC++実装の方が、こういうときには都合が良いのだ)

サンプルコードとしてCIFAR10のネットワークを組んで見た。(学習データは流し込んでいないので何も出来ないが、まぁ、一応動いてるぽいことは確認できる。
以下、そのテストコードの断片。定数とか全くdefineもしていないが、大変ではないのでなんとなく分かるでしょう。

#include "dqn.h"

bool writeDnn( Dnn* dnn, const char *filename ) {
    OutDnnStream outStream( filename );
    const bool rc = outStream.writeDnn( *dnn );
    if( !rc ) {
        log_error( "DNN write failed" );
    }
    outStream.close();
    return rc;
}

bool readDnn( Dnn* &dnn, const char *filename ) {
    InDnnStream inStream( filename );
    dnn = inStream.readDnn( true );
    inStream.close();
    if( dnn == 0 ) {
        return log_error( "Unable to read a DNN" );
    }
    return true;
}

bool setupDnn( Dnn &dnn ) {
    /*
    DnnBuilder builder( dnn, ACTIVATION_RELU );
    builder.addLayerInput( 1, 1, 2 );
    builder.addLayerFullyConnected( 20 );
    builder.addLayerSoftmax( 2 );
    dnn.initialize();
    const unsigned long expected_count = 5;
    const unsigned long count = dnn.count();
    if( count != expected_count ) {
        return log_error( "We have set up %lu layers, expected %lu", count, expected_count );
    }
    return true;
    */

    //layer_defs = [];
    //layer_defs.push({type:'input', out_sx:32, out_sy:32, out_depth:3});
    //layer_defs.push({type:'conv', sx:5, filters:16, stride:1, pad:2, activation:'relu'});
    //layer_defs.push({type:'pool', sx:2, stride:2});
    //layer_defs.push({type:'conv', sx:5, filters:20, stride:1, pad:2, activation:'relu'});
    //layer_defs.push({type:'pool', sx:2, stride:2});
    //layer_defs.push({type:'conv', sx:5, filters:20, stride:1, pad:2, activation:'relu'});
    //layer_defs.push({type:'pool', sx:2, stride:2});
    //layer_defs.push({type:'softmax', num_classes:10});
    //net = new convnetjs.Net();
    DnnBuilder builder( dnn, ACTIVATION_RELU );
    dnn.addLayerInput( 32, 32, 3 );             // Input layer for 32x32 RGB image
    // Convolution / Activation / Pool set:
    dnn.addLayerConvolution( 5, 16, 1, 2 );     // 16 5x5 filters for convolution
    dnn.addLayerRectifiedLinearUnit();          // Activation function for previous layer.
    dnn.addLayerPool( 2, 2 );
    // Convolution / Activation / Pool set:
    dnn.addLayerConvolution( 5, 20, 1, 2 );     // 20 5x5 filters for convolution
    dnn.addLayerRectifiedLinearUnit();          // Activation function for previous layer.
    dnn.addLayerPool( 2, 2 );
    // Convolution / Activation / Pool set:
    dnn.addLayerConvolution( 5, 20, 1, 2 );     // 20 5x5 filters for convolution
    dnn.addLayerRectifiedLinearUnit();          // Activation function for previous layer.
    dnn.addLayerPool( 2, 2 );
    // Fully connected layer with softmax:
    dnn.addLayerFullyConnected( 10 );               //  1, 1, 10
    dnn.addLayerSoftmax();                          // Output classifier
    dnn.initialize();
    const unsigned long count = dnn.count();
    if( count != 12 ) {
        return log_error( "Expected 12 layers, received %lu", count );
    }
    return true;
}

bool trainDnn( DnnTrainer* &trainer, Dnn &dnn ){
    DnnTrainerSGD* trainer_sgd = new DnnTrainerSGD( &dnn );
    //trainer_sgd -> learningRate( 0.01  );
    //trainer_sgd -> l2Decay( 0.001 );

    trainer_sgd -> learningRate( 0.0001 );
    trainer_sgd -> momentum(     0.0    );
    trainer_sgd -> batchSize(    1      );
    trainer_sgd -> l2Decay(      0.0    );

    trainer = trainer_sgd;
}
#include <stdio.h>
#include "unittest_dqn.h"

#include "dqn.h"

bool unittest_dqn(){
    puts("#unittest_dqn");

    ////DNN0
    puts("##DNN0");
    {
        //setup
        Dnn* dnn = new Dnn();
        setupDnn( *dnn );

        //init
        {
            //1D
            //DNN_NUMERIC *input  = dnn->getDataInput();
            //DNN_NUMERIC *output = dnn->getDataOutput();

            //2D
            Matrix *input  = dnn -> getMatrixInput();
            DNN_NUMERIC *output = dnn->getDataOutput();
            DNN_NUMERIC *data = input->data();

            //input
            for(int i=0;i<32*32*3;i++){
                *data = 0.56;
                data++;
            }

            if( !dnn->predict()) {
                log_error( "Problem with forward propagation through the network." );
            }

            //output
            puts("before train");
            for(int i=0;i<10;i++){
                DNN_NUMERIC probability = output[i];
                printf("%d:%.17g
",i,probability);
            }
            puts("");
        }

        //train
        {
            for(int i=0;i<200;i++){
                printf("epoch:%d
",i);

                //2D
                Matrix *input  = dnn -> getMatrixInput();
                //DNN_NUMERIC *output = dnn->getDataOutput();
                DNN_NUMERIC *data = input->data();

                //input
                for(int i=0;i<32*32*3;i++){
                    *data = 0.56;
                    data++;
                }
                //output
                //DNN_INTEGER outut[10];
                //for(int i=0;i<10;i++){
                //    if( i==7 ){
                //        outptu[i]=1;
                //    }else{
                //        output[i]=0;
                //    }
                //}

                //train
                DnnTrainer* trainer;
                trainDnn( trainer, *dnn );
                for(int i=0;i<10;i++){
                    trainer->train( 4 );
                }
                delete trainer;
            }
        }

        //save
        {
            writeDnn( dnn, "/work/dqn.dnn" );
        }

        //test
        {
            //1D
            //DNN_NUMERIC *input  = dnn->getDataInput();
            //DNN_NUMERIC *output = dnn->getDataOutput();

            //2D
            Matrix *input  = dnn -> getMatrixInput();
            DNN_NUMERIC *output = dnn->getDataOutput();
            DNN_NUMERIC *data = input->data();

           //input
            for(int i=0;i<32*32*3;i++){
                *data = 0.56;
                data++;
            }

            if( !dnn->predict()) {
                log_error( "Problem with forward propagation through the network." );
            }

            //output
            puts("after train");
            for(int i=0;i<10;i++){
                DNN_NUMERIC probability = output[i];
                printf("%d:%.17g
",i,probability);
            }
            puts("");
        }

        delete dnn;
    }


    ////DNN1
    puts("##DNN1");
    {
        //setup
        Dnn* dnn;
        readDnn( dnn, "/work/dqn.dnn" );

        //test
        {
            //1D
            //DNN_NUMERIC *input  = dnn->getDataInput();
            //DNN_NUMERIC *output = dnn->getDataOutput();

            //2D
            Matrix *input  = dnn -> getMatrixInput();
            DNN_NUMERIC *output = dnn->getDataOutput();
            DNN_NUMERIC *data = input->data();

           //input
            for(int i=0;i<32*32*3;i++){
                *data = 0.56;
                data++;
            }

            if( !dnn->predict()) {
                log_error( "Problem with forward propagation through the network." );
            }

            //output
            puts("other dnn");
            for(int i=0;i<10;i++){
                DNN_NUMERIC probability = output[i];
                printf("%d:%.17g
",i,probability);
            }
            puts("");
        }

        delete dnn;
    }

    return true;
}

やっていることを説明すると、適当な画像(32*32*RGB)を作って(ここでは適当な定数0.56とかで埋めているだけ)、適当なクラスに分類させる(ここでは10クラスで4番目)。これを適当なエポック数回して、結果を表示している。で、結果を表示してみると、4番目に分類されていることがわかる。ここまで動けば、あとは実際のデータを流すだけになるわけだ。この間にネットワークのセーブとロードもしている。これが簡単だったのも、tfs_dnnの良かった点だ。

以下、実行結果。初期時は、10クラスの分類が出来ない状態なので、ほぼ全ての確率が、0.1程度になっている。要するに頭が空っぽで区別がついていない。

before train
0:0.10000918021328996
1:0.098242792002857668
2:0.10447675242618829
3:0.099182862821362683
4:0.095431864564279781
5:0.10084649519781451
6:0.095886253996930959
7:0.10136782841270547
8:0.10467869788694967
9:0.099877272477620987

以下、200epoch回したあとの結果。相当確信を持って、「4」の分類が出来ている。

after train
0:0.0048271780099294285
1:0.0047157796022943496
2:0.0058145865112718608
3:0.0041484372577943324
4:0.95417696648567618
5:0.0053042266069439504
6:0.0042285995151014229
7:0.0053715219628276412
8:0.0064264998998294675
9:0.0049862041483313317

名前は、意図的にDQNとしている。強化学習用の母体のコードにするつもり。ここから拡張をしてDQN-FPGAチップを作ろうかなと思っている。学習を行う文字通りニューロチップができるわけだ。ロボットの頭脳に採用するつもり。このサンプルは、DQNのオリジナル論文https://arxiv.org/pdf/1312.5602.pdf のネットワーク構造が、とても良く似ていたので、フィルタのサイズだけ変えれば、ATARIゲームのPingPongくらいはこなせる人工知能と同等になるだろう。

これだけではまだまだなので、ルールベース人工知能とドッキングさせるかな。。。

この実装でCIFAR10は動くはずだけど、私のしょぼいマシンでは、どれくらいの学習時間がかかるのやら。。。それもあるんで、こんなことしか出来ないんです。申し訳ない・・・。追実装なので出来ることは分かっているのだが。。。涙

1
0
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0