24
21

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Pytorchでtensorboardを利用する

Posted at

はじめに

pytorchでの学習結果をtensorboardで確認するための手順。

環境

python==3.6.3
pytorch==0.4.0

tensorflowとtensorboardXのインストール

pip install tensorflow
pip install tensorboardx

tensorboardに食わせるログの設定の仕方

import torch
from torch.utils.data import DataLoader
import tensorboardX as tbx # tensorboardXのインポート[ポイント1]
~(中略)~

# model definition
model = Darknet(opt.model_config_path)
# data definition
dataloader = torch.utils.data.DataLoader(
    ListDataset(train_path), batch_size=opt.batch_size, shuffle=False, num_workers=opt.n_cpu
)
~(中略)~

# SummaryWriterのインスタンス作成[ポイント2]
writer = tbx.SummaryWriter()

# 学習スタート
for epoch in range(opt.epochs):
    for batch_i, (_, imgs, targets) in enumerate(dataloader):
        imgs = Variable(imgs.type(Tensor))
        targets = Variable(targets.type(Tensor), requires_grad=False)
        optimizer.zero_grad()
        loss = model(imgs, targets)
        loss.backward()
        optimizer.step()

        # tensorboard用log出力設定1[ポイント3]
        writer.add_scalar('data/total_loss', loss.item(), (epoch + 1) * batch_i)
        # tensorboard用log出力設定2[ポイント4]
        writer.add_scalars('data/loss',
                           {
                               'x': model.losses["x"],
                               'y': model.losses["y"],
                               'w': model.losses["w"],
                               'h': model.losses["h"],
                               'conf': model.losses["conf"],
                               'cls': model.losses["cls"]},
                           (epoch + 1) * batch_i)

    # tensorboard用log出力設定3[ポイント5]
    writer.add_scalars('data/metrics',
                    {
                        'recall': model.losses["recall"],
                        'precision': model.losses["precision"]},
                    (epoch))

# tensorboard用の値のjsonファイルへの保存[ポイント6]
writer.export_scalars_to_json("./all_scalars.json")
# SummaryWriterのclose[ポイント7]
writer.close()

得られるtensorboardの画面

スクリーンショット 2019-01-08 16.58.56.png

ポイントの解説

ポイント1:tensorboardXのインポート

単純にインポートするだけ。ここでエラーが出たら、pipでtensorboardXをインストール。

ポイント2:SummaryWriterのインスタンス作成

writer = tensorboardX.SummaryWriter()の形で作成し、このwriterにiterationごとに出力していく。

ポイント3, 4

writer.add_scalar(名称, 保存するデータ, iteration数)でwriterにaddしていく。
writer.add_scalar()は、単一のデータを、writer.add_scalars()は保存するデータにdict形式で指定することで、複数の値を一度に保存できる。
名称をfoo/bar形式にすることで、tensorboardのログの表示のまとまりがfoo単位となり、その下にbar1, bar2と連なる形で出力ができる。

ポイント5

writer.add_scalerをまたしているが、名称をlossからmetricsに変えたことで、表示のまとまりが別となり、lossのかたまりとmetricsのかたまりの別に別れて出力することを明示している。

ポイント6

tensorboardの値をjsonで保存する。機会があるかどうか不明だが、このjsonを保存しておいて、lossの最小値のモデルの選択などが可能となる。

ポイント7

処理終了後にwriter.close()で終了する。

これにより、実行スクリプトの配下にruns/というディレクトリができ、以下のコマンドでtensorboardで上記の値の確認ができる。

tensorboard --logdir runs/

その他

詰まったポイント

tensorboard --logdir runs/を実行したら、以下のエラーが発生した。

RuntimeError: module compiled against API version 0xc but this version of numpy is 0xb
ImportError: numpy.core.multiarray failed to import
ImportError: numpy.core.umath failed to import
ImportError: numpy.core.umath failed to import
2019-01-08 15:21:35.905720: F tensorflow/python/lib/core/bfloat16.cc:675] Check failed: PyBfloat16_Type.tp_base != nullptr
Abort trap: 6

pip uninstall numpy
pip install --no-cache-dir numpy==1.14.5
で解決した。

画像などその他情報のtensorboard保存

writer.add_image('Image', x, n_iter)でimageを保存できそう。
その他にも、writer.add_audio()やwriter.add_text()などが存在する。

24
21
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
24
21

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?