LoginSignup
1
2

More than 3 years have passed since last update.

chainerのTrainerのextensionのEvaluatorを普通に使うと、テストデータ数÷バッチサイズに余りがある場合重みを付けて補正しないと計算結果がずれる

Posted at

現象の紹介

ChainerはTrainerがとても便利なので使いたい。
numpy.arange(10)の平均を計算し、ログに出すことを考える。4.5になるのが正しい。
Trainerを用いる際はchainer.reportを使えば良い感じにログに入るので、これを使う。

import chainer
import numpy

# データセット
dataset = numpy.arange(10, dtype=float)
train_iterator = chainer.iterators.SerialIterator(dataset, batch_size=10)
test_iterator = chainer.iterators.SerialIterator(dataset, batch_size=3, repeat=False)  # 10÷3なので余りがある!

# モデルとか
class Model(chainer.Chain):
    def __call__(self, x):
        x = chainer.Variable(x)
        chainer.report({'x_mean': chainer.functions.mean(x)}, self)
        return x

model = Model()
optimizer = chainer.optimizers.SGD().setup(model)
updater = chainer.training.StandardUpdater(train_iterator, optimizer)

# 学習
extensions = [
    chainer.training.extensions.Evaluator(test_iterator, model),
    chainer.training.extensions.LogReport(),
    chainer.training.extensions.PrintReport([
        'iteration',
        'main/x_mean',  # Updaterのxの平均
        'validation/main/x_mean',  # Evaluatorのxの平均
    ]),
]
trainer = chainer.training.Trainer(updater, stop_trigger=(5, 'epoch'), extensions=extensions)
trainer.run()

良さそうに見えるが、実際はUpdaterとEvaluatorで計算結果がずれてしまう。

iteration   main/x_mean  validation/main/x_mean
1           4.5          4.25
2           4.5          4.75
3           4.5          5.08333
4           4.5          4.75
5           4.5          4.41667

原因はすごく単純で、Updaterは10データ全部の平均を取っているのに対し、Evaluatorは10データを3つずつ取って平均したものを平均しているから。最後の1バッチには1データしかなく、重みが変わっちゃうので、値が変わる。

回避策は2つある。

回避策1 重みを指定する

Evaluatorは内部でDictSummaryを使っていて、これは重み付き平均に対応している。
ちょっと直感的ではないが、計算結果と重みをタプルにして与えることができる。

# モデルとか
class Model(chainer.Chain):
    def __call__(self, x):
        x = chainer.Variable(x)
        chainer.report({'x_mean': (chainer.functions.mean(x), len(x))}, self)  # 重み(データ数)と一緒にタプルにする
        return x

こうすると、計算結果もちゃんと補正される。

iteration   main/x_mean  validation/main/x_mean
1           4.5          4.5
2           4.5          4.5
3           4.5          4.5
4           4.5          4.5
5           4.5          4.5

回避策2 テストデータ数÷バッチサイズに余りが出ないようにする

愚直に余りが出ないようにすれば、計算結果はずれない。

test_iterator = chainer.iterators.SerialIterator(dataset, batch_size=5, repeat=False)  # 10÷5なので余りがない!
iteration   main/x_mean  validation/main/x_mean
1           4.5          4.5
2           4.5          4.5
3           4.5          4.5
4           4.5          4.5
5           4.5          4.5

その他

accuracyの遷移具合などを見るときは、人間が目でなんとなく平均すれば良いんけど、
Optunaとかで自動パラメータ調整するときとかは端数によっては偏りが無視できなくなって、
調整に支障が出てきそうだからちゃんと調べてみた。

別にこれはChainerに限った話ではなく、どのフレームワークでも起こり得るので、気をつけていきたい。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2