197
191

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Kerasのcallbackを試す(modelのsave,restore/TensorBoard書き出し/early stopping)

Last updated at Posted at 2016-06-30

はじめに

KerasはTheano,TensorFlowベースの深層学習ラッパーライブラリです.大まかな使い方は以前記事を書いたので興味のある方はそちらをごらんください.Kerasにはいくつか便利なcallbackが用意されており,modelやparameterを書き出すタイミングやTensorBoardへのログを吐き出すタイミングを指定することができます(公式サイト).今回はそれらのcallbackを,実際に試しながらみていきます.

基本的な使い方

1. コールバックの作成

es_cb = keras.callbacks.EarlyStopping(monitor='val_loss', patience=0, verbose=0, mode='auto')
tb_cb = keras.callbacks.TensorBoard(log_dir=log_filepath, histogram_freq=1)

まずはコールバックを作成します.次説で簡単に解説しますが,Kerasにはデフォルトで何種類かのコールバックが用意されています.上の例では,学習が収束した際に途中で学習を打ち切る用のコールバックと,TensorFlowのTensorBoardに書き出す用のコールバックを作成しています.

2. コールバックのfit()への設定

1.で作成したコールバック関数は,model.fit()を呼び出す際に,下記のように登録します.

model.fit(X_train, Y_train, batch_size=128, nb_epoch=20, verbose=0, validation_data=(X_test, Y_test), callbacks=[es_cb, tb_cb])

collbacksに,登録したいコールバック関数を配列形式で渡します.コールバックの呼ばれるタイミングは関数の種類によって異なります.

コールバックの種類

EarlyStopping

学習ループに収束判定を付与することができます.監視する値を設定し,それが収束したら自動的にループを抜ける処理になります.

keras.callbacks.EarlyStopping(monitor='val_loss', patience=0, verbose=0, mode='auto')
arguments description
monitor 監視値指定.例えば,monitor='val_loss'
patience ループの最低数
verbose 保存時に標準出力にコメントを出すか指定.{0, 1}
mode 上限,下限どちらの側に収束した場合に収束判定を出すかの規定.{auto, min, max}

上記の設定で,以下のように学習ループ途中であっても収束判定がかかり,ループから抜けることができます(下記の表示はverbose=1に設定していた場合).

Epoch 5/15
3325/3325 [==============================] - 23s - loss: 0.0138 - val_loss: 0.0070
Epoch 6/15
3000/3325 [==========================>...] - ETA: 2s - loss: 0.0088
Epoch 00005: early stopping
3325/3325 [==============================] - 25s - loss: 0.0087 - val_loss: 0.0089

TensorBoard

TensorFlowの可視化ツールであるTensorboardを使うための関数です.学習の直前と毎epochの終了時に呼び出されます.

keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True)
arguments description
log_dir logファイルを書き出すディレクトリの指定
histogram_freq Tensorboardのhistogram用データを出力する頻度の指定.histogram_freq=1の場合は毎epochデータが出力される
write_graph モデルのグラフを出力するか否かの指定

裏では,学習の直前にtf.merge_all_summariestf.train.SummaryWriterが呼ばれ,毎epochの終わりにadd_summaryが呼ばれてlogが出力されています.**TensorBoardにlogを出力する場合には,keras.backend.tensorflow_backendを利用して明示的にtensorflowのセッションを登録する必要があります.**詳しくは下記の例を参照してください.

ModelCheckpoint

Kerasにおける,model,parameterの保存,読み込みはjson, yaml形式で行います.

modelの保存,読み込みはmodel.to_json()/model.to_yaml()model_from_json()/model_from_yaml()を使用します.

from keras.models import Sequential, model_from_json

json_string = model.to_json()
model = model_from_json(json_string)

学習したParameterの保存&読み込みは,save_weights/load_weightsを使用します(h5pyが必要).

model.save_weights('param.hdf5')
model.load_weights('param.hdf5')

Parameterの保存に使用するコールバック関数はModelCheckpointです. この関数は毎epochの終わりで呼ばれます

arguments description
filepath 保存ファイル名
monitor 監視値指定.例えば,monitor='val_loss'
verbose 保存時に標準出力にコメントを出すか指定.{0, 1}
save_best_only 精度がよくなった時だけ保存するかどうか指定.Falseの場合は毎epoch保存.
mode 上限,下限どちらの側に収束した場合に収束判定を出すかの規定.{auto, min, max}

filepathが同じ名前場合上書きされるので,名前を変えるために指定した変数の値を自動入力してくれる機能が備わっています.
指定できる変数は、epoch, loss, acc, val_loss, val_accです.
例えばfilepathを下記のように指定した場合,その時の値を自動で入れてくれます.

fpath = 'weights.{epoch:02d}-{loss:.2f}-{acc:.2f}-{val_loss:.2f}-{val_acc:.2f}.hdf5'
cp_cb = keras.callbacks.ModelCheckpoint(filepath=fpath, monitor='val_loss', verbose=1, save_best_only=True, mode='auto')

LearningRateScheduler

学習係数を動的に変更することができる関数です.

例えば,

start = 0.03
stop = 0.001
nb_epoch = 1000
learning_rates = np.linspace(start, stop, nb_epoch)

lr_cb = keras.callbacks.LearningRateScheduler(lambda epoch: float(learning_rates[epoch]))

のように指定することで,epochに応じて学習係数を変動させることができます.

実際に試してみる

上記のコールバックを実際にRNNにsin波を学習させるサンプル上で試してみます.コード自体の詳しい解説はリンクを参照してください.

まず,学習,テストデータを作ります.

import pandas as pd
import numpy as np
import math
import random
random.seed(0)
random_factor = 0.05
steps_per_cycle = 80
number_of_cycles = 50

df = pd.DataFrame(np.arange(steps_per_cycle * number_of_cycles + 1), columns=["t"])
df["sin_t"] = df.t.apply(lambda x: math.sin(x * (2 * math.pi / steps_per_cycle)+ random.uniform(-1.0, +1.0) * random_factor))
(X_train, y_train), (X_test, y_test) = train_test_split(df[["sin_t"]], n_prev =length_of_sequences)  

続いて,kerasでモデルを作り,学習を回します.

from keras.models import Sequential  
from keras.layers.core import Dense, Activation  
from keras.layers.recurrent import LSTM
import keras.backend.tensorflow_backend as KTF
import tensorflow as tf
from keras.callbacks import EarlyStopping, TensorBoard, ModelCheckpoint
import os

in_out_neurons = 1
hidden_neurons = 300
length_of_sequences = 100

old_session = KTF.get_session()

with tf.Graph().as_default():
    session = tf.Session('')
    KTF.set_session(session)
    KTF.set_learning_phase(1)
    model = Sequential()  
    with tf.name_scope("inference") as scope:
        model.add(LSTM(hidden_neurons, input_shape=(length_of_sequences, in_out_neurons), return_sequences=False))  
        model.add(Dense(in_out_neurons))  
        model.add(Activation("linear"))       
    model.summary()
    fpath = './tensorlog/weights.{epoch:02d}-{loss:.2f}-{val_loss:.2f}.hdf5'
    cp_cb = ModelCheckpoint(filepath = fpath, monitor='val_loss', verbose=1, save_best_only=True, mode='auto')
    es_cb = EarlyStopping(monitor='val_loss', patience=2, verbose=1, mode='auto')
    tb_cb = TensorBoard(log_dir="./tensorlog", histogram_freq=1)
    model.compile(loss="mean_squared_error", optimizer="rmsprop",  metrics=['accuracy'])  
    model.fit(X_train, y_train, batch_size=600, nb_epoch=10, validation_split=0.05, verbose=1, callbacks=[cp_cb, es_cb, tb_cb]) 
json_string = model.to_json()
open(os.path.join(f_model,'./tensorlog/rnn_model.json'), 'w').write(json_string)
KTF.set_session(old_session)

各種コールバックを作成し,fit関数のcallbacksにまとめて登録しています.
上記のように,tensorboardのコールバックを用いる際は,tf.Session('')で明示的にtensorflowのセッションを作成し,set_session関数によりkeras側に登録する必要があります.

以上の処理によって,./tensorlogディレクトリ内に, weights.01-0.02-0.00.hdf5といったパラメータログ等が保存されます.

おわりに

今回はKerasのコールバック関数をそれぞれ紹介し,実際に使用してみました.呼び出すタイミング等をさらに細かく制御するために,自前で関数を規定する方法もありますが,今回は割愛します.

197
191
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
197
191

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?