チェックポイントとは
TensorFlow の訓練中にモデルの重みや状態を保存・復元する仕組みを「チェックポイント」と呼びます。これを活用することで、長時間かかる訓練の途中で中断・再開したり、最良モデルだけを切り出して検証に使ったり、学習曲線を比較したり、といった運用が可能になります。
チェックポイントを作成する意味
-
訓練の中断・再開
長時間の訓練を途中で止めても、最後に保存した重みから再開でき、無駄な計算を省けます。
-
過学習の抑制
検証データで精度が最も良かったエポックのモデルだけを残し、訓練後に復元して使うことで、過学習モデルの利用を避けられます。
-
実験の再現性
同じ重みファイルを用いれば、他環境や後日になってもまったく同じ性能を再現できます。
-
デプロイ準備
訓練完了後、最良モデルを切り出してサーバやアプリに組み込む際に便利です。
具体的な実装
以下では、MNIST データセットに対してシンプルな全結合モデルを訓練し、チェックポイントを作成する流れを示します。
import os
import tensorflow as tf
from tensorflow import keras
# (省略:データ読み込み/前処理、create_model 定義など)
model = create_model()
まず、保存先のパスを決めてコールバックを作成します。
checkpoint_path = "training_1/cp.ckpt.weights.h5"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
verbose=1
)
-
save_weights_only=True
:モデル構造ではなく「重みだけ」を保存 -
verbose=1
:保存時にログを出力
あとは fit
メソッドの callbacks
引数に渡すだけです。
model.fit(
train_images, train_labels,
epochs=10,
validation_data=(test_images, test_labels),
callbacks=[cp_callback]
)
訓練後、training_1/
フォルダ内に cp.ckpt.weights.h5
が生成されます。
チェックポイントを使ってモデルを復元する
保存した重みを読み込んで評価・再訓練を行う手順です。
# 新しいモデルインスタンスを作成
model = create_model()
# 未訓練モデルの精度を評価
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Untrained model accuracy: {acc:.2%}")
# チェックポイントから重みを読み込み
model.load_weights(checkpoint_path)
# 復元モデルの精度を再評価
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Restored model accuracy: {acc:.2%}")
これで、訓練途中あるいは訓練完了時のモデル性能を簡単に比較できます。
エポックごとにチェックポイントを作成する
もっと細かく、たとえば「5 エポックごとに保存したい」「ファイル名にエポック番号を埋め込みたい」といった場合は、以下のように設定します。
import math
# ファイル名にエポック番号を含める
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt.weights.h5"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
batch_size = 32
n_batches = math.ceil(len(train_images) / batch_size)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
save_freq=5 * n_batches, # 5 エポックごとに保存
verbose=1
)
model = create_model()
# エポック 0 の重みを明示的に保存
model.save_weights(checkpoint_path.format(epoch=0))
model.fit(
train_images, train_labels,
epochs=50,
batch_size=batch_size,
validation_data=(test_images, test_labels),
callbacks=[cp_callback],
verbose=0
)
-
save_freq
にバッチ数単位を渡すことで、任意のエポック間隔に対応 -
{epoch:04d}
で出力ファイル名に 4 桁ゼロ埋めのエポック番号を埋め込み
保存された複数ファイルの中から最新のものを読み込むには、
import glob
# ディレクトリ中の *.weights.h5 を取得し、最終更新日時でソート
files = glob.glob("training_2/*.weights.h5")
latest = max(files, key=os.path.getctime)
model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Latest restored model accuracy: {acc:.2%}")
のようにすれば OK です。
まとめ
- チェックポイント を作ると、訓練の中断・再開や過学習抑制、最良モデルの切り出しが容易に。
-
ModelCheckpoint
コールバックで 重みのみ/構造+重み を定期的に保存可能。 -
save_freq
やファイル名フォーマットを工夫すると、任意の間隔・名前で複数ファイルを管理できる。 - 保存ファイルは
load_weights
で簡単に復元でき、再評価や継続訓練に利用できる。
これらを活用して、再現性と効率を高めた機械学習開発を進めましょう。
参考
Tensorflow公式チュートリアル