LoginSignup
5
5

More than 3 years have passed since last update.

機械学習の可視化ツールTensorboardを使ってみた

Posted at

はじめに

Tensorboardを初めて使いグラフを書きその便利さに感動したので、共有します。
Deep LearningのフレームワークはPyTorchを使用しました。

インストール

Anacondaを使っているので以下のコマンドでTensorboardをインストールします。

conda install tensorboard

コーディング

tb.py
import numpy as np
from torch.utils.tensorboard import SummaryWriter#グラフを書くSummaryWriterをimport

np.random.seed(1000)

x = np.random.randn(1000)

writer = SummaryWriter(log_dir="./logs")#インスタンス生成 保存するディレクトリも指定

for i in range(1000):
    writer.add_scalar("x", x[i], i)#値を書き込む
    writer.add_scalar("sin", np.sin(i), i)

writer.close()#閉じる

ファイル名をtensorboard.pyにするとmoduleと被りImportErrorになるので注意しましょう。

解説

簡単にいうと、上のコードでは、ランダムな値を持つ配列とsin関数をプロットしています。

SummaryWriterをimport

from torch.utils.tensorboard import SummaryWriter
Tensorboardでグラフの描画に必要なmoduleであるSummaryWriterをimportします。

インスタンス生成

writer = SummaryWriter(log_dir="./logs")
これはカレントディレクトリにlogsディレクトリを作成し、そのlogsの中にTensorboard用のファイルが保存されます。

値を代入

writer.add_scalar("x", x[i], i)で配列の値を入れます。
writer.add_scalar(tags, scalar_value, global_step)となっており、tagsでグラフの名前を指定して、scalar_valueで保存する値を代入、global_stepでグラフの横軸の間隔を指定します。

閉じる

writer.close()で最後に閉じましょう。

グラフをみる

tb.pyの実行

上記のコードを実行しましょう。グラフが描画されます。

python tb.py

グラフをみる

以下のコマンドを実行しましょう。--logdir=""で保存したディレクトリを指定しましょう。
今回は./logsです。

tensorboard --logdir="./logs"

そうすると、以下の文がターミナルに出力されます。

TensorBoard 2.2.1 at http://localhost:8000/ (Press CTRL+C to quit)

ローカルサーバーが立ち上がるので、ブラウザにhttp://localhost:8000/と打ちましょう。

スクリーンショット 2020-08-12 22.48.19.png

chromeでみると、グラフが綺麗にプロットしていることがわかります。

ssh先のグラフをみる

Deep Learningのコードは計算量が多くローカルPC(手元のPC)では莫大な時間がかかるので、
研究室にあるサーバーのGPUにsshしてサーバー上でコードを回すことがデフォルトです。
では、そういう場合はリモートサーバーで描画したグラフをローカルPCでどうやってみるのでしょうか?

リモートサーバーにsshする

sshをする時に、-Lオプションを用いてクライアント(ローカルPC)のlocalhost:9000をリモートサーバーのユーザ名@サーバーのIPアドレス:8000に繋げます。

@ローカルPC
ssh ユーザ名@サーバーのIPアドレス -L 9000:localhost:8000

リモートサーバーでtb.pyの実行

sshしたリモートサーバーでグラフを描画するコードを実行しましょう。

@リモートサーバー
python tb.py

Tensorboardの実行

sshしたリモートサーバーでグラフをみるためのコマンドを実行しましょう。
sshした際にローカルPCに繋いだポートは8000なので、--portオプションで8000を指定して実行しましょう。

@リモートサーバー
tensorboard --logdir="./logs" --port 8000

以下のような文が出力されます。

@リモートサーバー
TensorBoard 2.2.1 at http://localhost:8000/ (Press CTRL+C to quit)

グラフをみる

さっきはhttp://localhost:8000/をブラウザに入力したら、グラフがみれましたが今回は見れません。

今回はリモートサーバーのポート8000とローカルPCのポート9000を繋げたので、
ローカルPCのブラウザでhttp://localhost:9000/と入力すれば、さっきと同じグラフが見れます。

スクリーンショット 2020-08-12 22.48.19.png

まとめ

PyTorchでTensorboardを用いてグラフを描画しました。
また、ssh先のリモートサーバーでまわしたコードのグラフをローカルPCでみる方法を紹介しました。
私もこのTensorboardとssh -Lを利用してDeep Learningに活用していきたいと思います。

5
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
5