LoginSignup
8

More than 1 year has passed since last update.

【Keras入門(3)】TensorBoardで見える化

Last updated at Posted at 2019-05-28

入門者に向けてKerasの初歩を解説します。
TensorBoardも含めてGoogle Colaboratoryを使っているのでローカルでの環境準備すらしていません。Google Colaboratoryについては「Google Colaboratory概要と使用手順(TensorFlowもGPUも使える)」の記事を参照ください。

以下のシリーズにしています。

また、TensorBoardに関しては以下の記事を参照ください。

使ったPythonライブラリ

Google Colaboratoryでインストール済の以下のライブラリとバージョンを使っています。KerasはTensorFlowに統合されているものを使っているので、ピュアなKerasは使っていません。Pythonは3.6です。

  • tensorflow: 1.13.1
  • Numpy: 1.16.3

Pythonプログラム

プログラム全体はGitHubにあります。ディープラーニングのモデルについては、記事「【Keras入門(1)】単純なディープラーニングモデル定義を参照ください。

コールバックとTensorBoardログ保存

KerasでTensorBoardログを保存する場合には、コールバックという機能を使用します。これは、訓練中に呼び出すことができる仕組みです。
記事「【Keras入門(1)】単純なディープラーニングモデル定義」で使用したfit関数に渡すことで使用できます。

model.fit(data, labels, epochs=300, validation_split=0.2, callbacks=li_cb)

経験ないですが、fit_generator関数にも使えるようです。

fit関数に渡しているli_cbという変数は以下のように定義しています。配列にすることで複数の機能をコールバックで使用できます。今回は、TensorBoard関数を使ったTensorBoardログの保存です。保存先はタイムスタンプを付加したフォルダ名にしておきます。何度も実行した場合に見分けやすくて便利です。

from datetime import datetime
from tensorflow.keras.callbacks import Callback, TensorBoard

# TensorBoardのログ保存先(タイムスタンプを付けておくと見るときに便利)
logdir = "log/run-{}/".format(datetime.utcnow().strftime("%Y%m%d%H%M%S"))

# CallBackの指定
li_cb = []
li_cb.append(TensorBoard(log_dir=logdir, histogram_freq=1, write_graph=True, write_grads=True))

Google Colaboratory上でTensorBoard表示

TensorFlow2.0だと、簡単にGoogle Colaboratory上でTensorBoardが使えるようですが、2019年5月時点でのデフォルトのTensorFlow1.13.1では使えないのでひと手間加えます。

ngrokという公開URLで表示させてくれるサービスを使ってTensorBoardを公開させます。
まずは、wgetで取得してunzipします。

!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip ngrok-stable-linux-amd64.zip

tensorboardを起動してngrokで公開します。

get_ipython().system_raw(
    'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'
    .format(logdir)
)
# Tunnel port 6006 (TensorBoard assumed running)
get_ipython().system_raw('./ngrok http 6006 &')

cURLで公開URL情報を取得。URLを開けばTensorBoard画面です!
もちろん、ログをダウンロードしてローカルPCで見ることもできます。

# ここで表示されたURLを開く
! curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

03.tensorboard.JPG

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
8