12
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Chainer v1からChainer v2への移行

Last updated at Posted at 2017-08-03

Chainerが2.0へアップデートされて後方互換性がなくなったので,AlexNetを例に2.0.2で動かせるように修正した部分を書いていきます.

ネットワーク部分

主にこの部分が変更になり,それ以外は特に変更する必要がないと思います.
v1.xの際は__init__部分で直接class継承を行なっていましたが,v2.xではwith self.init_scope():を用いるように変更になりました.
それに伴いinit部分の記法が変わっています.これまで fc_n = L.function( ), と記述していたのに対して, self.fc_n = L.function( )というようにself.を追加し,最後のが必要なくなります.

ここのを削除し忘れるとTypeError: 'tuple' object is not callable となるので注意.
このエラーでデータセットの方が悪いのかと思ってどハマりしました.

v1.x_Alex.py
class Alex(chainer.Chain):
  def __init__(self):
    super(Alex, self).__init__(
      conv1=L.Convolution2D(None,  96, 11, stride=4),
      conv2=L.Convolution2D(None, 256,  5, pad=2),
      conv3=L.Convolution2D(None, 384,  3, pad=1),
      conv4=L.Convolution2D(None, 384,  3, pad=1),
      conv5=L.Convolution2D(None, 256,  3, pad=1),
      fc6=L.Linear(None, 4096),
      fc7=L.Linear(None, 4096),
      fc8=L.Linear(None, 1000),
    )
  self.train = True
...
v2.x_Alex.py
class Alex(chainer.Chain):
  def __init__(self):
    super(Alex, self).__init__()
    with self.init_scope():
      self.conv1 = L.Convolution2D(None,  96, 11, stride=4)
      self.conv2 = L.Convolution2D(None, 256,  5, pad=2)
      self.conv3 = L.Convolution2D(None, 384,  3, pad=1)
      self.conv4 = L.Convolution2D(None, 384,  3, pad=1)
      self.conv5 = L.Convolution2D(None, 256,  3, pad=1)
      self.fc6 = L.Linear(None, 4096)
      self.fc7 = L.Linear(None, 4096)
      self.fc8 = L.Linear(None, 1000)
...

また,学習時にのみ反映させるtrainが関数内から削除され,with chainer.using_config()で指定しています.

model = L.Classifier(Alex())
with chainer.using_config('train',False):
  y = (x)

基本的にこのchainer.configで設定を行うのが共通になるらしく,usecudnnもこれで設定することになっています.
chainer.config.use_cudnn = 'never' と記述すると一括で使用しない設定にできます.

反復部分

変えないと動かない部分はありませんが,trainer.extendの種類が増えています.

if extensions.PlotReport.available():
  trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'],
                 'epoch', file_name='loss.png'))
  trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

と記述することで自動でlossとaccuracyのグラフが出力されるようですが,自分の環境では入っているはずのmatplotlibがないと言われてうまく動きませんでした.

動かしてみて

変更して実際に動かして見た感じですが,一般的なLinearだけで構成されているfeed-forwardネットワークは特に今までと変わらなかったのですが,AlexNetをベースにしたネットワークでは極端に学習が遅くなっている気がします.
おそらく書き方が悪いのかと…

こちらの公式にその他の変更点も書かれています
Chainer docs

12
10
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?