Help us understand the problem. What is going on with this article?

ChainerX うごかしてみるよ

More than 1 year has passed since last update.

本日は

カレンダー3日目埋まってないので引き続き書いていこうと思います。
ChainerXの導入を昨日書きましたが、今回は実際に動かしてみましょう

MNIST 動きました。

カレンダー2日目の記事のようにMNISTは動きました。
実際に動かしてみるとわかるんですけれど通常のMNISTの例とChainerX側のExampleのMNISTを同時に走らせるとその速さが実感できると思います。

ChainerXの train_mnist.py では13秒ほどで終了します。

ResNet50 の方はどうでしょう?

MNIST の他にも
https://github.com/chainer/chainer/tree/master/chainerx_cc/examples/imagenet
にてImageNetの学習用スクリプトも用意されています。

1000クラスのデータ落とすと気が遠くなりそうなので Caltech101 を用いた https://github.com/terasakisatoshi/chainer-caltech-101 を使うことにします。

データの準備

iDeep を使ってCPUでのChainerの推論速度をアップしよう の記事通りに行います。

学習

$ python train_chx_resnet.py train.txt test.txt -j 8 -d cuda:0

そうすると下記のように進捗が出ます。プログレスバーは tqdm をつかっています。

Namespace(batchsize=32, device='cuda:1', epoch=50, iteration=None, loaderjob=8, mean='mean.npy', root='.', test=False, train='train.txt', val='test.txt', val_batchsize=250)
100%|██████████████████████████████████████████████████████████████| 213/213 [00:49<00:00,  4.33it/s]
epoch 1... loss=4.42953271484375,   accuracy=0.112, elapsed_time=51.761133432388306
100%|██████████████████████████████████████████████████████████████| 213/213 [00:47<00:00,  4.53it/s]
epoch 2... loss=4.365013671875, accuracy=0.16,  elapsed_time=99.80267882347107
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 3... loss=4.41390185546875,   accuracy=0.108, elapsed_time=147.61588311195374
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.54it/s]
epoch 4... loss=4.7586708984375,    accuracy=0.1,   elapsed_time=195.54803729057312
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 5... loss=4.062323486328125,  accuracy=0.18,  elapsed_time=243.3883125782013
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 6... loss=4.39565087890625,   accuracy=0.148, elapsed_time=291.20321798324585
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.54it/s]
epoch 7... loss=4.57482275390625,   accuracy=0.136, elapsed_time=339.06728291511536
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 8... loss=4.13157861328125,   accuracy=0.172, elapsed_time=386.8594512939453
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 9... loss=4.64961767578125,   accuracy=0.152, elapsed_time=434.6652834415436
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 10... loss=4.12386865234375,  accuracy=0.196, elapsed_time=482.47300910949707
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 11... loss=4.1064111328125,   accuracy=0.18,  elapsed_time=530.2803256511688
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 12... loss=4.09931884765625,  accuracy=0.192, elapsed_time=578.1029961109161
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 13... loss=3.987043701171875, accuracy=0.184, elapsed_time=625.9076073169708
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 14... loss=4.33284912109375,  accuracy=0.16,  elapsed_time=673.7270267009735
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 15... loss=4.13598095703125,  accuracy=0.164, elapsed_time=721.5430347919464
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
epoch 16... loss=3.932227294921875, accuracy=0.22,  elapsed_time=769.3474335670471
100%|██████████████████████████████████████████████████████████████| 213/213 [00:46<00:00,  4.55it/s]
  • 動かしてわかったんですけれど、Save機能がついていませんでしたOTL... ChainerX用のレイヤー定義されている params のプロパティたちを逐次保存していくことにしていくとできるはずです。
  • あと途中で accuracy がヘタるのと epoch 50 じゃたりなさそう。
    • データが101種類しかないのに出力が1000にしていたからかな???

いちおう通常のResNet50版も用意しておきました。

$ python train.py train.txt test.txt -j 8 --gpu 0 -a resnet50
Namespace(arch='resnet50', batchsize=32, epoch=50, gpu=1, initmodel=None, loaderjob=8, mean='mean.npy', out='result', resume='', root='.', test=False, train='train.txt', val='test.txt', val_batchsize=250)
epoch       iteration   main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  lr        
4           1000        3.33122                           0.29725                                  0.01        
     total [####..............................................]  9.38%
this epoch [##################################................] 69.14%
      1000 iter, 4 epoch / 50 epochs
    4.6415 iters/sec. Estimated time to finish: 0:34:40.738241.

Namespace(arch='resnet50', batchsize=32, epoch=50, gpu=1, initmodel=None, loaderjob=8, mean='mean.npy', out='result', resume='', root='.', test=False, train='train.txt', val='test.txt', val_batchsize=250)
end of build_schedule()
Creating new backward schedule...
end of build_schedule()
epoch       iteration   main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  lr        ch [################################..................] 64.45%
4           1000        3.37801                           0.295188                                 0.01        
9           2000        2.07377                           0.4835                                   0.01        
14          3000        1.43291                           0.617625                                 0.01        
18          4000        0.971574                          0.725125                                 0.01        
23          5000        0.65357                           0.80825                                  0.01        
28          6000        0.426234                          0.868156                                 0.01        
32          7000        0.262041                          0.916875                                 0.01        
37          8000        0.181415                          0.942906                                 0.01        
42          9000        0.129002                          0.959125                                 0.01        
46          10000       0.0772636                         0.977                                    0.01        

お、、、

$ python predict.py -a resnet50 --ideep
total accuracy rate =  0.9647058823529412

50 epoch でよさそうですね。:innocent:

MNISTの場合に比べると学習スピードに大きな速度差は見られません。ResNet50自体大きめのネットワークなので学習時にGPUが常に100%使うことが多いです。そのためChainerXと学習スピードに大きな違いは出ないんだとおもいます。もっと軽いネットワークであればChainerXのちからでGPUの資源を活用しながら学習スピードを向上させることができると思います。

推論方法はどうするんだろうか?

仮にモデルがセーブ出来てロードできる機構があるとすると推論のコードはどうなるんだろう?

学習スクリプトの evaluate を見てみます。

def evaluate(model, X_test, Y_test, eval_size, batch_size):
    N_test = X_test.shape[0] if eval_size is None else eval_size

    if N_test > X_test.shape[0]:
        raise ValueError(
            'Test size can be no larger than {}'.format(X_test.shape[0]))

    model.no_grad()

    # TODO(beam2d): make chx.array(0, dtype=...) work
    total_loss = chx.zeros((), dtype=chx.float32)
    num_correct = chx.zeros((), dtype=chx.int64)
    for i in range(0, N_test, batch_size):
        x = X_test[i:min(i + batch_size, N_test)]
        t = Y_test[i:min(i + batch_size, N_test)]

        y = model(x)
        total_loss += compute_loss(y, t) * batch_size
        num_correct += (y.argmax(axis=1).astype(t.dtype)
                        == t).astype(chx.int32).sum()

    model.require_grad()

    mean_loss = float(total_loss) / N_test
    accuracy = int(num_correct) / N_test
    return mean_loss, accuracy

この実装にならって書くと良いと思います。

model.no_grad() を書いたあとその中で推論作業しているみたいですね。
model.require_grad でパラメータ更新を有効にする流れだと思います。

ここの部分は将来的にコンテキストマネージャーで

with chainer.using_config('train', False):
   # do something

と同様な書き方ができる設計になると予想しています。

以上高級料理 ChainerX クッキングでした。

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away