23
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

自作のソースコードをChainer v2 alphaに対応させてみた

Posted at

Chainer v2 alpha

Chainer v2 alphaがリリースされたので、自作のソースコードを対応させてみました。
以下のサイトを参考にしました。

動作環境

  • OS: Windows 10(64bit)
  • Python: Python 2.7.11 :: Anaconda custom (64-bit)
  • GPU: GTX 1080

使用したリポジトリ

以前作成したCIFAR-10の画像認識のリポジトリにChainer v2用のブランチを作成しました。

インストール

Chainer Meetupのスライドにあるように、以下のコマンドでインストールできました。
念のため--no-cache-dirをつけました。

$ pip install chainer --pre --no-cache-dir
$ pip install cupy --no-cache-dir

とりあえず実行してみる

Chainer v2の修正により後方互換性が壊れるので、動作しないことは予想できているのですがとりあえず実行してみます。

$ python src/train.py -g 0 -m vgg -p model\temp9 -b 100 --iter 200 --lr 0.1 --optimizer sgd --weight_decay 0.0001 --lr_decay_iter 100,150

以下のエラーが発生しました。

Traceback (most recent call last):
  File "src\train.py", line 143, in <module>
    cifar_trainer.fit(train_x, train_y, valid_x, valid_y, test_x, test_y, on_epoch_done)
  File "c:\project_2\chainer\chainer-cifar\src\trainer.py", line 26, in fit
    return self.__fit(x, y, valid_x, valid_y, test_x, test_y, callback)
  File "c:\project_2\chainer\chainer-cifar\src\trainer.py", line 40, in __fit
    loss, acc = self.__forward(x_batch, y[batch_index])
  File "c:\project_2\chainer\chainer-cifar\src\trainer.py", line 75, in __forward
    y = self.net(x, train=train)
  File "c:\project_2\chainer\chainer-cifar\src\net.py", line 360, in __call__
    h = self.bconv1_1(x, train)
  File "c:\project_2\chainer\chainer-cifar\src\net.py", line 28, in __call__
    h = self.bn(self.conv(x), test=not train)
TypeError: __call__() got an unexpected keyword argument 'test'

chainer.links.BatchNormalization__call__の引数にtestがないにもかかわらず渡しているというエラーです。

Chainer v2で動作するように修正する

chainer.functions.dropoutの呼び出し引数からtrainを削除

Chainer v2からはdropoutの引数trainが不要になるので削除します。

修正例:

修正前:
h = F.dropout(F.max_pooling_2d(h, 2), 0.25, train=train)
修正後:
h = F.dropout(F.max_pooling_2d(h, 2), 0.25)

chainer.links.BatchNormalizationの呼び出し引数からtestを削除

BatchNormalizationの引数testが不要になるのでdropoutの場合と同様に削除します。

修正前:

class BatchConv2D(chainer.Chain):
    def __init__(self, ch_in, ch_out, ksize, stride=1, pad=0, activation=F.relu):
        super(BatchConv2D, self).__init__(
            conv=L.Convolution2D(ch_in, ch_out, ksize, stride, pad),
            bn=L.BatchNormalization(ch_out),
        )
        self.activation=activation

    def __call__(self, x, train):
        h = self.bn(self.conv(x), test=not train)
        if self.activation is None:
            return h
        return self.activation(h)

修正後:

class BatchConv2D(chainer.Chain):
    def __init__(self, ch_in, ch_out, ksize, stride=1, pad=0, activation=F.relu):
        super(BatchConv2D, self).__init__(
            conv=L.Convolution2D(ch_in, ch_out, ksize, stride, pad),
            bn=L.BatchNormalization(ch_out),
        )
        self.activation=activation

    def __call__(self, x): # trainを削除
        h = self.bn(self.conv(x)) # testを削除
        if self.activation is None:
            return h
        return self.activation(h)

学習中でない場合の処理をchainer.using_config('train', False)で括る

dropoutBatchNormalizationの呼び出しから引数train, testを削除しました。
このままだとこれらの関数が学習中のモードで動作してしまいます。
Chainer v2からはwith chainer.using_config('train', ):を使って学習中かどうかを制御します。

    with chainer.using_config('train', False):
        # 学習中でない場合の処理(テストデータの精度計算など)

学習中かどうかをchainer.config.trainで区別する

Chainer v2からchainer.configが追加され、学習中かどうか、back propagationが必要かどうかなどをconfigで判断できるようになりました。
私は今まで学習中かどうかを、以下のように自作関数のtrain引数で判定していたのですが、v2からはtrain引数は必要なくconfiguration.config.trainで判定すればよいです。

修正前:

def my_func(x, train=True):
    if train:
        # 学習中の処理
    else:
        # 学習中でない場合の処理

修正後:

def my_func(x):
    if chainer.config.train:
        # 学習中の処理
    else:
        # 学習中でない場合の処理

back propagationが不要な場合にchainer.using_config('train', False)で括る

back propagationが不要な処理をchainer.using_config('train', False)で括ります。
今までchainer.Variable生成時にvolatileフラグをONにしていたケースが該当します。

Chainer v2 alphaでは必要ないが今後(beta以降で)必要になること

chainer.Variableの引数volatileを削除

v2 alphaの段階では残っていますが今後chainer.Variablevolatileは削除される予定です。
volatileの代わりにchainer.using_config('enable_backprop', )で制御することになります。
chainer.functionschainer.linksの呼び出しにVariableではなくNumpy配列、Cupy配列を渡せるようになっているので、Variableの生成処理も削除する選択肢もあると思います。

修正前:

    x = Variable(xp.asarray(batch_x), volatile=Train)

修正後:

    with chainer.using_config('enable_backprop', False):
        x = Variable(xp.asarray(batch_x))

修正後の実行

c:\project_2\chainer-cifar>python src\train.py -g 0 -m vgg -p model\temp -b 100 --iter 200 --lr 0.1 --optimizer sgd --weight_decay 0.0001 --lr_decay_iter 100,150
DEBUG: nvcc STDOUT mod.cu
   ライブラリ C:/Users/user_name/AppData/Local/Theano/compiledir_Windows-10-10.0.14393-Intel64_Family_6_Model_58_Stepping_9_GenuineIntel-2.7.11-64/tmpiwxtcf/265abc51f7c376c224983485238ff1a5.lib とオブジェクト C:/Users/user_name/AppData/Local/Theano/compiledir_Windows-10-10.0.14393-Intel64_Family_6_Model_58_Stepping_9_GenuineIntel-2.7.11-64/tmpiwxtcf/265abc51f7c376c224983485238ff1a5.exp を作 成中

Using gpu device 0: GeForce GTX 1080 (CNMeM is disabled, cuDNN 5105)
C:\Users\user_name\Anaconda\lib\site-packages\theano-0.8.2-py2.7.egg\theano\sandbox\cuda\__init__.py:600: UserWarning: Your cuDNN version is more recent than the one Theano officially supports. If you see any problems, try updating Theano or downgrading cuDNN to version 5.
  warnings.warn(warn)
loading dataset...
start training
epoch 0 done
train loss: 2.29680542204 error: 85.5222222221
valid loss: 1.95620539665 error: 81.3800000548
test  loss: 1.95627536774 error: 80.6099999845
test time: 1.04036228008s
elapsed time: 23.5432411172
epoch 1 done
train loss: 1.91133875476 error: 76.8000000185
valid loss: 1.83026596069 error: 73.6399999559
test  loss: 1.8381768012 error: 73.2900000066
test time: 0.993011643337s

Chainer v2にする前からTheano周りでWarningが出ているのですが、動作しているようです。

最後に

Chainer v2向けの修正は難しくはないのですが、dropoutBatchNormalizationの使用箇所が多かったので、その分修正量としては多くなりました。
修正の結果として、いくつかの関数が持っていた引数trainが不要になったのでコードが少しすっきりしました。
v1向けに実装した多くのコードがv2では動かなくなると思うので、v2が正式リリースされた直後は拾ってきたv1向けのコードを動かそうとしても動かないという事象が多く見られる気がします。

23
24
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
23
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?