はじめに
Tensorboardを初めて使いグラフを書きその便利さに感動したので、共有します。
Deep LearningのフレームワークはPyTorchを使用しました。
インストール
Anacondaを使っているので以下のコマンドでTensorboardをインストールします。
conda install tensorboard
コーディング
import numpy as np
from torch.utils.tensorboard import SummaryWriter#グラフを書くSummaryWriterをimport
np.random.seed(1000)
x = np.random.randn(1000)
writer = SummaryWriter(log_dir="./logs")#インスタンス生成 保存するディレクトリも指定
for i in range(1000):
writer.add_scalar("x", x[i], i)#値を書き込む
writer.add_scalar("sin", np.sin(i), i)
writer.close()#閉じる
ファイル名をtensorboard.pyにするとmoduleと被りImportErrorになるので注意しましょう。
解説
簡単にいうと、上のコードでは、ランダムな値を持つ配列とsin関数をプロットしています。
SummaryWriterをimport
from torch.utils.tensorboard import SummaryWriter
Tensorboardでグラフの描画に必要なmoduleであるSummaryWriterをimportします。
インスタンス生成
writer = SummaryWriter(log_dir="./logs")
これはカレントディレクトリにlogsディレクトリを作成し、そのlogsの中にTensorboard用のファイルが保存されます。
値を代入
writer.add_scalar("x", x[i], i)
で配列の値を入れます。
writer.add_scalar(tags, scalar_value, global_step)
となっており、tagsでグラフの名前を指定して、scalar_valueで保存する値を代入、global_stepでグラフの横軸の間隔を指定します。
閉じる
writer.close()
で最後に閉じましょう。
グラフをみる
tb.pyの実行
上記のコードを実行しましょう。グラフが描画されます。
python tb.py
グラフをみる
以下のコマンドを実行しましょう。--logdir=""
で保存したディレクトリを指定しましょう。
今回は./logs
です。
tensorboard --logdir="./logs"
そうすると、以下の文がターミナルに出力されます。
TensorBoard 2.2.1 at http://localhost:8000/ (Press CTRL+C to quit)
ローカルサーバーが立ち上がるので、ブラウザにhttp://localhost:8000/
と打ちましょう。
chromeでみると、グラフが綺麗にプロットしていることがわかります。
ssh先のグラフをみる
Deep Learningのコードは計算量が多くローカルPC(手元のPC)では莫大な時間がかかるので、
研究室にあるサーバーのGPUにsshしてサーバー上でコードを回すことがデフォルトです。
では、そういう場合はリモートサーバーで描画したグラフをローカルPCでどうやってみるのでしょうか?
リモートサーバーにsshする
sshをする時に、-Lオプションを用いてクライアント(ローカルPC)のlocalhost:9000をリモートサーバーのユーザ名@サーバーのIPアドレス:8000に繋げます。
ssh ユーザ名@サーバーのIPアドレス -L 9000:localhost:8000
リモートサーバーでtb.pyの実行
sshしたリモートサーバーでグラフを描画するコードを実行しましょう。
python tb.py
Tensorboardの実行
sshしたリモートサーバーでグラフをみるためのコマンドを実行しましょう。
sshした際にローカルPCに繋いだポートは8000なので、--portオプションで8000を指定して実行しましょう。
tensorboard --logdir="./logs" --port 8000
以下のような文が出力されます。
TensorBoard 2.2.1 at http://localhost:8000/ (Press CTRL+C to quit)
グラフをみる
さっきはhttp://localhost:8000/
をブラウザに入力したら、グラフがみれましたが今回は見れません。
今回はリモートサーバーのポート8000とローカルPCのポート9000を繋げたので、
ローカルPCのブラウザでhttp://localhost:9000/
と入力すれば、さっきと同じグラフが見れます。
まとめ
PyTorchでTensorboardを用いてグラフを描画しました。
また、ssh先のリモートサーバーでまわしたコードのグラフをローカルPCでみる方法を紹介しました。
私もこのTensorboardとssh -Lを利用してDeep Learningに活用していきたいと思います。