Chainerが2.0へアップデートされて後方互換性がなくなったので,AlexNetを例に2.0.2で動かせるように修正した部分を書いていきます.
ネットワーク部分
主にこの部分が変更になり,それ以外は特に変更する必要がないと思います.
v1.xの際は__init__部分で直接class継承を行なっていましたが,v2.xではwith self.init_scope():を用いるように変更になりました.
それに伴いinit部分の記法が変わっています.これまで fc_n = L.function( ),
と記述していたのに対して, self.fc_n = L.function( )
というようにself.
を追加し,最後の,
が必要なくなります.
ここの,
を削除し忘れるとTypeError: 'tuple' object is not callable
となるので注意.
このエラーでデータセットの方が悪いのかと思ってどハマりしました.
class Alex(chainer.Chain):
def __init__(self):
super(Alex, self).__init__(
conv1=L.Convolution2D(None, 96, 11, stride=4),
conv2=L.Convolution2D(None, 256, 5, pad=2),
conv3=L.Convolution2D(None, 384, 3, pad=1),
conv4=L.Convolution2D(None, 384, 3, pad=1),
conv5=L.Convolution2D(None, 256, 3, pad=1),
fc6=L.Linear(None, 4096),
fc7=L.Linear(None, 4096),
fc8=L.Linear(None, 1000),
)
self.train = True
...
class Alex(chainer.Chain):
def __init__(self):
super(Alex, self).__init__()
with self.init_scope():
self.conv1 = L.Convolution2D(None, 96, 11, stride=4)
self.conv2 = L.Convolution2D(None, 256, 5, pad=2)
self.conv3 = L.Convolution2D(None, 384, 3, pad=1)
self.conv4 = L.Convolution2D(None, 384, 3, pad=1)
self.conv5 = L.Convolution2D(None, 256, 3, pad=1)
self.fc6 = L.Linear(None, 4096)
self.fc7 = L.Linear(None, 4096)
self.fc8 = L.Linear(None, 1000)
...
また,学習時にのみ反映させるtrainが関数内から削除され,with chainer.using_config()
で指定しています.
model = L.Classifier(Alex())
with chainer.using_config('train',False):
y = (x)
基本的にこのchainer.config
で設定を行うのが共通になるらしく,usecudnn
もこれで設定することになっています.
chainer.config.use_cudnn = 'never'
と記述すると一括で使用しない設定にできます.
反復部分
変えないと動かない部分はありませんが,trainer.extendの種類が増えています.
if extensions.PlotReport.available():
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'],
'epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'],
'epoch', file_name='accuracy.png'))
と記述することで自動でlossとaccuracyのグラフが出力されるようですが,自分の環境では入っているはずのmatplotlibがないと言われてうまく動きませんでした.
動かしてみて
変更して実際に動かして見た感じですが,一般的なLinearだけで構成されているfeed-forwardネットワークは特に今までと変わらなかったのですが,AlexNetをベースにしたネットワークでは極端に学習が遅くなっている気がします.
おそらく書き方が悪いのかと…
こちらの公式にその他の変更点も書かれています
Chainer docs