Python
matplotlib
Chainer

Chainerのextensions.PlotReportをログ表示に変更する

Chainerにはextensions.PlotReport()というextensions.LogReportで出力されるログ結果を可視化してくれる機能がある。しかし、学習がある程度進むとlossもほとんど変化しなくなるので対数グラフで可視化したくなる。最近リリースされたchaineruiを利用するのも手だが、extensions.PlotReport()を少し変更するだけで対数グラフで出力できるようになる。

動作環境

  • Ubuntu 16.04.3 LTS
  • Python 3.5.2
  • chainer 3.2

グラフのイメージ

以下で説明する関数を他のextensions同様にimportして使用すれば以下のように対数グラフのPlotReportが生成される。

log_plot.png

対数グラフにする

ソースコードの取得

まずはextensions.PlotReport()をchainerのソースコードから持ってくる。

ソースコードの改変

変更前

def __call__(self, trainer)の中間程度の以下を

f = plt.figure()
a = f.add_subplot(111)
a.set_xlabel(self._x_key)
if self._grid:
    a.grid()

for k in keys:
    xy = data[k]
    if len(xy) == 0:
        continue

    xy = numpy.array(xy)
    plt.yscale("log")
    a.plot(xy[:, 0], xy[:, 1], marker=self._marker, label=k)

if a.has_data():
    if self._postprocess is not None:
        self._postprocess(f, a, summary)

変更後

以下のように修正する(# 追加の部分)。最後のannotate()はお好みで。

f = plt.figure()
a = f.add_subplot(111)
a.set_xlabel(self._x_key)
if self._grid:
     a.grid(which='major', color='gray', linestyle=':')# 追加
     a.grid(which='minor', color='gray', linestyle=':')# 追加
    # a.grid()は削除

for k in keys:
    xy = data[k]
    if len(xy) == 0:
        continue

    xy = numpy.array(xy)
    plt.yscale("log")# 追加
    a.plot(xy[:, 0], xy[:, 1], marker=self._marker, label=k)

if a.has_data():
    if self._postprocess is not None:
        self._postprocess(f, a, summary)

    # 追加(validationの最新の値を表示)
    a.annotate('validation\n{0:8.6f}'.format(xy[-1, 1]),
               xy=(xy[-1]), xycoords='data',
               xytext=(-90, 75), textcoords='offset points',
               bbox=dict(boxstyle="round", fc="0.8"),
               arrowprops=dict(arrowstyle="->",
                               connectionstyle="arc,angleA=0,armA=50,rad=10"))

できれば標準で実装して欲しい(PlotReportflg指定する感じ)けど、github初心者なので誰か代わりに報告してください

以上。