はじめに
Tensorboardでmaplotlibのプロットを表示する方法が欲しかったため考えてみた。
Tensorboardでは画像を表示することが出来るので、プロットを画像化して表示すればよいのでは?と考えた。
ということで、その方法をまとめます。
なお、ログの出力はtensorboardX
を介して行っている点に注意してください。
方法
まず、matplotlibでプロットを画像で得るため、バックエンドをAgg
に切り替える。
バックエンドの切り替えはimport matplotlib.pyplot
の前に行う必要がある点に注意。
import matplotlib as mpl
mpl.use('Agg')
プロットの画像データは以下の方法で取得できる。
import matplot.pyplot as plt
import numpy as np
fig = plt.figure() # 繰り返し表示する際にはplt.figure(0)等でFigureを指定したほうが良い
# 何かしらのプロット
fig.canvas.draw() # Canvasに描画
plot_image = fig.canvas.renderer._renderer # プロットを画像データとして取得
# tensorboardXはchannel firstのようなので、それに合わせる
plot_image_array = np.array(plot_image).transpose(2, 0, 1)
あとはtensorboardX
でログを出力するだけ。
from tensorboardX import SummaryWriter
summary_writer = SummaryWriter(logdir='hoge') # writerの初期化
summary_writer.add_image('plot', plot_image_array) # 画像の追加
以上です。
その他に良い方法があれば是非教えてください。