現象の紹介
ChainerはTrainerがとても便利なので使いたい。
numpy.arange(10)の平均を計算し、ログに出すことを考える。4.5になるのが正しい。
Trainerを用いる際はchainer.reportを使えば良い感じにログに入るので、これを使う。
import chainer
import numpy
# データセット
dataset = numpy.arange(10, dtype=float)
train_iterator = chainer.iterators.SerialIterator(dataset, batch_size=10)
test_iterator = chainer.iterators.SerialIterator(dataset, batch_size=3, repeat=False)  # 10÷3なので余りがある!
# モデルとか
class Model(chainer.Chain):
    def __call__(self, x):
        x = chainer.Variable(x)
        chainer.report({'x_mean': chainer.functions.mean(x)}, self)
        return x
model = Model()
optimizer = chainer.optimizers.SGD().setup(model)
updater = chainer.training.StandardUpdater(train_iterator, optimizer)
# 学習
extensions = [
    chainer.training.extensions.Evaluator(test_iterator, model),
    chainer.training.extensions.LogReport(),
    chainer.training.extensions.PrintReport([
        'iteration',
        'main/x_mean',  # Updaterのxの平均
        'validation/main/x_mean',  # Evaluatorのxの平均
    ]),
]
trainer = chainer.training.Trainer(updater, stop_trigger=(5, 'epoch'), extensions=extensions)
trainer.run()
良さそうに見えるが、実際はUpdaterとEvaluatorで計算結果がずれてしまう。
iteration   main/x_mean  validation/main/x_mean
1           4.5          4.25
2           4.5          4.75
3           4.5          5.08333
4           4.5          4.75
5           4.5          4.41667
原因はすごく単純で、Updaterは10データ全部の平均を取っているのに対し、Evaluatorは10データを3つずつ取って平均したものを平均しているから。最後の1バッチには1データしかなく、重みが変わっちゃうので、値が変わる。
回避策は2つある。
回避策1 重みを指定する
Evaluatorは内部でDictSummaryを使っていて、これは重み付き平均に対応している。
ちょっと直感的ではないが、計算結果と重みをタプルにして与えることができる。
# モデルとか
class Model(chainer.Chain):
    def __call__(self, x):
        x = chainer.Variable(x)
        chainer.report({'x_mean': (chainer.functions.mean(x), len(x))}, self)  # 重み(データ数)と一緒にタプルにする
        return x
こうすると、計算結果もちゃんと補正される。
iteration   main/x_mean  validation/main/x_mean
1           4.5          4.5
2           4.5          4.5
3           4.5          4.5
4           4.5          4.5
5           4.5          4.5
回避策2 テストデータ数÷バッチサイズに余りが出ないようにする
愚直に余りが出ないようにすれば、計算結果はずれない。
test_iterator = chainer.iterators.SerialIterator(dataset, batch_size=5, repeat=False)  # 10÷5なので余りがない!
iteration   main/x_mean  validation/main/x_mean
1           4.5          4.5
2           4.5          4.5
3           4.5          4.5
4           4.5          4.5
5           4.5          4.5
その他
accuracyの遷移具合などを見るときは、人間が目でなんとなく平均すれば良いんけど、
Optunaとかで自動パラメータ調整するときとかは端数によっては偏りが無視できなくなって、
調整に支障が出てきそうだからちゃんと調べてみた。
別にこれはChainerに限った話ではなく、どのフレームワークでも起こり得るので、気をつけていきたい。