LoginSignup
19
11

More than 5 years have passed since last update.

kerasで学習中の狙ったタイミングのモデルを保存する

Last updated at Posted at 2017-12-09

せっかく時間をかけて学習したのに…

  • 間違えてCtrl+Cを押して終了してしまった!
  • 放置しすぎて明らかに過学習してしまった!
  • NaN値が飛び出してモデルが意味のないものになってしまった!
  • すごい良さそうなValidationの値が出てたのに最後に保存されてるモデルはもう既にそれじゃない!

などなど、経験したことはありませんか?
時間も電気代も使ってたくさん計算したのにもう1回やり直し、なんてなったら機械学習のモチベーションは遥か彼方だと思います。

そういう事を避ける意味でも、そうでなくとも学習中の適切なタイミングのモデルを保存していきたいと思いこんなクラスを作りました。

基本的な使い方はModelCheckpointとあまり変わりませんが、これそのままだとあんまり融通が利かないんですよね。

このCustomModelCheckpointは各メトリクスやValidationメトリクスに対して、特定の値を超えた(下回った)タイミングのモデルを指定のファイル名で保存します。複数の値を指定することもでき、その場合はANDで(すべて満たす場合のみ)保存します。しかもファイル名にはメトリクスの数値を含められるのでファイルをパッと見で比較できます。

使い方

howto.py
model.compile(optimizer=<some optimizer>,
              loss=<some loss>,
              metrics=['accuracy', custom_metrics])

checkpoint = CustomModelCheckpoint(
    filepath="./model-{epoch:05d}-{val_acc:.4f}-{val_custom_metrics:.4f}.h5",
    thresholds={
        'acc': 0.8,
        'val_acc': 0.75,
        'val_custom_metrics': 0.5
    })

model.fit(X, y, validation_data=(X_test, y_test), callbacks=[checkpoint])

こんな感じで自作のmetricsにも使えます。基本的なルールとしてメトリクスの名前に val_ を付けるとValidationデータ時に計算した値が該当するようです。(また accuracy のみ acc という書き方ができるっぽい)

上の例の場合はトレーニングデータの精度が0.8以上かつバリデーションデータの精度が0.75以上かつcustom_metricsで指定した値の大きさがバリデーション時に0.5以上を満たしたepochの時のみモデルを「model-00012-0.8231-0.7821-0.6172.h5」みたいな名前で保存します。

イニシャライザに inverse=True を指定すると「以上」ではなく「以下」で判定します。

今のところ以上と以下を分ける程の柔軟さが無いので必要であれば正負をひっくり返したcustom_metricsにすれば良いと思います。(未検証(๑´ڡ`๑)てへぺロ)

19
11
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19
11