5
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ModelCheckPointのepoch数を変更する

Last updated at Posted at 2018-09-06

#開発環境
jupyterとDockerを使って、Kerasで学習してます。

#問題だったこと

モデルをロードしてさらに学習させたいけど、エポック数が0からになって上手く保存できない!
100epoch学習したあと追加で100epoch学習させたい!
という時にどうしたら良いかわからずハマってしまったのでその解決策

#保存されるファイル名を変化させる
さて、早速本題です.
どうすればファイル名の{epoch:02d}部分を変更させることができるのか.
答えは簡単!
model.fitメソッドに引数として 'initial_epoch' を追加すれば良いです.

model.fit(train_X, train_y, batch_size=batch_size, initial_epoch = 100, nb_epoch=nb_epoch,verbose=0, validation_data=(test_X, test_y),callbacks=[cp_cb])

これだけで学習が101epoch目から始まるようになります.
モデルをloadしなくなった時にはinitial_epochを0に戻すように気をつけましょう.

#コールバックとは
学習と一緒に使うと効果的なのが、ModelCheckPointやEarlyStoppingなどのコールバックと呼ばれる機能です.
まずはじめに、コールバックについて説明します.

コールバックは訓練中で適用される関数集合です.訓練中にモデル内部の状態と統計量を可視化する際に,コールバックを使います.SequentialとModelクラスの.fit()メソッドに(キーワード引数callbacksとして)コールバックのリストを渡すことができます.コールバックに関連するメソッドは,訓練の各段階で呼び出されます.

つまりModelCheckPointなどのコールバックはModelクラスの.fitメソッドとセットで使います!
fitメソッドについては後ほど触れるとして、早速ModelCheckPointについて詳しく見ていきましょう.

#ModelCheckPointとは

  • Kerasで用意されているコールバックの一つ
  • 各エポック終了後にモデルを保存してくれる

##使い方
keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)

##引数

  • filepath: 文字列,モデルファイルを保存するパス.
  • monitor: 監視する値.
  • verbose: 冗長モード, 0 または 1.
  • save_best_only: save_best_only=Trueの場合,監視しているデータによって最新の最良モデルが上書きされません.
  • mode: {auto, min, max}の内の一つが選択されます.save_best_only=Trueならば,現在保存されているファイルを上書きするかは,監視されている値の最大化か最小化によって決定されます.val_accの場合,この引数はmaxとなり,val_lossの場合はminになります.autoモードでは,この傾向は自動的に監視されている値から推定します.
  • save_weights_only: Trueなら,モデルの重みが保存されます (model.save_weights(filepath)),そうでないなら,モデルの全体が保存されます (model.save(filepath)).
  • period: チェックポイント間の間隔(エポック数).

(一部抜粋)詳しく見たい方はkeras Documentationから見れます

使うときはこんな感じ

fpath = './tensorlog/weights.{epoch:02d}-{loss:.2f}-{val_loss:.2f}.hdf5'
cp_cb = ModelCheckpoint(filepath = fpath, monitor='val_loss', verbose=0, save_best_only=True, mode='auto')

model.compile(loss='mse', optimizer = adam)

model.fit(train_X, train_y, batch_size=batch_size, nb_epoch=nb_epoch,verbose=0, validation_data=(test_X, test_y),callbacks=[cp_cb])

#model.fit
最後にmodel.fitメソッドについても復習しておきましょう.
##使い方
model.fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)
##引数

  • x: モデルが単一の入力を持つ場合は訓練データのNumpy配列,もしくはモデルが複数の入力を持つ場合はNumpy配列のリスト.
  • y: モデルが単一の入力を持つ場合は教師(targets)データのNumpy配列,もしくはモデルが複数の出力を持つ場合はNumpy配列のリスト.
  • batch_size: 整数またはNone.勾配更新毎のサンプル数を示す整数.指定しなければbatch_sizeはデフォルトで32になります.
  • epochs: 整数.訓練データ配列の反復回数を示す整数.エポックは,提供されるxおよびyデータ全体の反復です. initial_epochと組み合わせると,epochsは"最終エポック"として理解されることに注意してください.このモデルはepochsで与えられた反復回数だの訓練をするわけではなく,単にepochsという指標に試行が達するまで訓練します.
  • initial_epoch: 整数.訓練を開始するエポック(前回の学習を再開するのに便利です).
  • callbacks: keras.callbacks.Callbackインスタンスのリスト.訓練時に呼ばれるコールバックのリスト.詳細はcallbacksを参照.
  • steps_per_epoch: 整数またはNone.終了した1エポックを宣言して次のエポックを始めるまでのステップ数の合計(サンプルのバッチ).TensorFlowのデータテンソルのような入力テンソルを使用して訓練する場合,デフォルトのNoneはデータセットのサンプル数をバッチサイズで割ったものに等しくなります.それが決定できない場合は1になります.

(一部抜粋)詳しく見たい方はkeras Documentationから見れます

5
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?