現象の紹介
ChainerはTrainerがとても便利なので使いたい。
numpy.arange(10)
の平均を計算し、ログに出すことを考える。4.5
になるのが正しい。
Trainerを用いる際はchainer.report
を使えば良い感じにログに入るので、これを使う。
import chainer
import numpy
# データセット
dataset = numpy.arange(10, dtype=float)
train_iterator = chainer.iterators.SerialIterator(dataset, batch_size=10)
test_iterator = chainer.iterators.SerialIterator(dataset, batch_size=3, repeat=False) # 10÷3なので余りがある!
# モデルとか
class Model(chainer.Chain):
def __call__(self, x):
x = chainer.Variable(x)
chainer.report({'x_mean': chainer.functions.mean(x)}, self)
return x
model = Model()
optimizer = chainer.optimizers.SGD().setup(model)
updater = chainer.training.StandardUpdater(train_iterator, optimizer)
# 学習
extensions = [
chainer.training.extensions.Evaluator(test_iterator, model),
chainer.training.extensions.LogReport(),
chainer.training.extensions.PrintReport([
'iteration',
'main/x_mean', # Updaterのxの平均
'validation/main/x_mean', # Evaluatorのxの平均
]),
]
trainer = chainer.training.Trainer(updater, stop_trigger=(5, 'epoch'), extensions=extensions)
trainer.run()
良さそうに見えるが、実際はUpdaterとEvaluatorで計算結果がずれてしまう。
iteration main/x_mean validation/main/x_mean
1 4.5 4.25
2 4.5 4.75
3 4.5 5.08333
4 4.5 4.75
5 4.5 4.41667
原因はすごく単純で、Updaterは10データ全部の平均を取っているのに対し、Evaluatorは10データを3つずつ取って平均したものを平均しているから。最後の1バッチには1データしかなく、重みが変わっちゃうので、値が変わる。
回避策は2つある。
回避策1 重みを指定する
Evaluatorは内部でDictSummaryを使っていて、これは重み付き平均に対応している。
ちょっと直感的ではないが、計算結果と重みをタプルにして与えることができる。
# モデルとか
class Model(chainer.Chain):
def __call__(self, x):
x = chainer.Variable(x)
chainer.report({'x_mean': (chainer.functions.mean(x), len(x))}, self) # 重み(データ数)と一緒にタプルにする
return x
こうすると、計算結果もちゃんと補正される。
iteration main/x_mean validation/main/x_mean
1 4.5 4.5
2 4.5 4.5
3 4.5 4.5
4 4.5 4.5
5 4.5 4.5
回避策2 テストデータ数÷バッチサイズに余りが出ないようにする
愚直に余りが出ないようにすれば、計算結果はずれない。
test_iterator = chainer.iterators.SerialIterator(dataset, batch_size=5, repeat=False) # 10÷5なので余りがない!
iteration main/x_mean validation/main/x_mean
1 4.5 4.5
2 4.5 4.5
3 4.5 4.5
4 4.5 4.5
5 4.5 4.5
その他
accuracyの遷移具合などを見るときは、人間が目でなんとなく平均すれば良いんけど、
Optunaとかで自動パラメータ調整するときとかは端数によっては偏りが無視できなくなって、
調整に支障が出てきそうだからちゃんと調べてみた。
別にこれはChainerに限った話ではなく、どのフレームワークでも起こり得るので、気をつけていきたい。