16
27

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Chainerのレポーティング機能についてまとめてみる

Last updated at Posted at 2017-06-26

tl;dr

trainer.extend(extensions.LogReport())を書けば最低限のロギングがただで得られる。extensions.ParameterStatisticsをしっておくと幸せになれる。

目的

2016年6月、v1.11でChainerにTraining Loop Abstractionが導入されてから早一年ですが、オレオレ学習ループの資産(負債?)を使いたくて、手をだしていないという人も多いのではないでしょうか。そろそろいい加減にTrainerやらなんやらを使いこなしたいということで、あまり記事がヒットしない(しかし、とても重要な)学習メトリクスのレポーティングについてまとめました。

Chainerにおけるレポーティングメカニズムを理解する

深層学習をやるうえで、各種メトリクスを監視することはとても重要です。たとえば、Tensorflowにはsummaryという強力なレポーティング機能がついています。

公式にはいかにもリポーティングをしてくれそうなReporterというクラスがドキュメンテーションされていますが、公式のExampleを見るとReporterというクラスはどこにもでてきませんし、精度などを明示的に書き出しているところもありません。これはどういうことなのでしょう。

公式ExampleのL96-97を見てみると、LogReportPrintReportなどのextensionsTrainerオブジェクトに付加しています。

# Write a log of evaluation statistics for each epoch
trainer.extend(extensions.LogReport())

そして、chainer/chainer/training/trainer.py:L291-L299をみてみると、学習ループの始まりでwith reporter.scope(self.observation)と宣言しています。この宣言により学習ループ内で呼び出したchainer.reporter.report({'name': value_to_report})の呼び出しはすべてself.observationに格納されるようになります。

    def run(self):
        ....
        reporter = self.reporter
        stop_trigger = self.stop_trigger

        # main training loop
        try:
            while not stop_trigger(self):
                self.observation = {}
                with reporter.scope(self.observation):
                    update()
                    for name, entry in extensions:
                        if entry.trigger(self):
                             entry.extension(self)

つまり、特にReporterを明示的によばなくても、実はメトリクスは裏側でちゃんと集めているということになります。さて、集めたデータですが、entry.extension(self)のよびだしによって、chainer/chainer/training/extensions/log_report.py:L67-L88に渡されます。

    def __call__(self, trainer):
        # accumulate the observations
        keys = self._keys
        observation = trainer.observation
        summary = self._summary

        if keys is None:
            summary.add(observation)
        else:
            summary.add({k: observation[k] for k in keys if k in observation})

        if self._trigger(trainer):
            # output the result
            stats = self._summary.compute_mean()
            stats_cpu = {}
            for name, value in six.iteritems(stats):
                stats_cpu[name] = float(value)  # copy to CPU

            updater = trainer.updater
            stats_cpu['epoch'] = updater.epoch
            stats_cpu['iteration'] = updater.iteration
            stats_cpu['elapsed_time'] = trainer.elapsed_time

この関数内でエポック数などが付記されて、然るべき場所に出力されます。何らかのレポーティング機能を持つextensionが登録されていなければ、データは単に破棄されます。

ここで、公式のExampleでReporterを明示的に読んでいない理由がわかりました。しかし、chainer.reporter.reportを一度もよびだしていないのに、精度(accuracy)や損失(loss)はどうしてなんでしょうか。

そこで、chainer/chainer/links/model/classifier.pyをみてみると、公式の実装内でchainer.reporter.reportがよびだされていることがわかります。

        self.loss = self.lossfun(self.y, t)
        reporter.report({'loss': self.loss}, self)
        if self.compute_accuracy:
            self.accuracy = self.accfun(self.y, t)
            reporter.report({'accuracy': self.accuracy}, self)

つまり、trainer.extend(extensions.LogReport())とだけ書けば必要最低限のロギングは得られ、あとは自分のモデルの中でchainer.reporter.reportとさえ呼べば任意のレポーティングができるわけですね。便利です。

ちなみに、上記exampleを実行すると下記のようなレポーティングがresult/logに得られます。

[{u'elapsed_time': 6.940603971481323,
  u'epoch': 1,
  u'iteration': 600,
  u'main/accuracy': 0.9213500021273892,
  u'main/loss': 0.2787705701092879,
  u'validation/main/accuracy': 0.9598000049591064,
  u'validation/main/loss': 0.13582063710317016},
 {u'elapsed_time': 14.360282897949219,
  u'epoch': 2,
  u'iteration': 1200,
  ...

レポーティング内容を増やしてみる

これだけで十分便利ですが、extensions.ParameterStatisticsを使うことでTensorflowのtf.summary.histogramのようなリッチなモニタリングができます。

...
trainer.extend(extensions.ParameterStatistics(model))
...

結果にはモデルに含まれる各Linkの行列の代表値が自動的に集められ、追加されます。大変便利ですね。

[{u'None/predictor/l1/W/data/max': 0.18769985591371854,
  u'None/predictor/l1/W/data/mean': 0.0006860141372822189,
  u'None/predictor/l1/W/data/min': -0.21658104345202445,
  u'None/predictor/l1/W/data/percentile/0': -0.1320047355272498,
  u'None/predictor/l1/W/data/percentile/1': -0.08497818301255008,
  u'None/predictor/l1/W/data/percentile/2': -0.04122352957670082,
  u'None/predictor/l1/W/data/percentile/3': 0.0008963784146650747,
  u'None/predictor/l1/W/data/percentile/4': 0.0428067545834066,
  ...

上記の実行結果はgistにおいてあります。

16
27
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
27

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?