0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【TensorFlow】チェックポイント活用ガイド

Posted at

チェックポイントとは

TensorFlow の訓練中にモデルの重みや状態を保存・復元する仕組みを「チェックポイント」と呼びます。これを活用することで、長時間かかる訓練の途中で中断・再開したり、最良モデルだけを切り出して検証に使ったり、学習曲線を比較したり、といった運用が可能になります。

チェックポイントを作成する意味

  • 訓練の中断・再開

    長時間の訓練を途中で止めても、最後に保存した重みから再開でき、無駄な計算を省けます。

  • 過学習の抑制

    検証データで精度が最も良かったエポックのモデルだけを残し、訓練後に復元して使うことで、過学習モデルの利用を避けられます。

  • 実験の再現性

    同じ重みファイルを用いれば、他環境や後日になってもまったく同じ性能を再現できます。

  • デプロイ準備

    訓練完了後、最良モデルを切り出してサーバやアプリに組み込む際に便利です。

具体的な実装

以下では、MNIST データセットに対してシンプルな全結合モデルを訓練し、チェックポイントを作成する流れを示します。

import os
import tensorflow as tf
from tensorflow import keras
# (省略:データ読み込み/前処理、create_model 定義など)

model = create_model()

まず、保存先のパスを決めてコールバックを作成します。

checkpoint_path = "training_1/cp.ckpt.weights.h5"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=1
)

  • save_weights_only=True:モデル構造ではなく「重みだけ」を保存
  • verbose=1:保存時にログを出力

あとは fit メソッドの callbacks 引数に渡すだけです。

model.fit(
    train_images, train_labels,
    epochs=10,
    validation_data=(test_images, test_labels),
    callbacks=[cp_callback]
)

訓練後、training_1/ フォルダ内に cp.ckpt.weights.h5 が生成されます。

チェックポイントを使ってモデルを復元する

保存した重みを読み込んで評価・再訓練を行う手順です。

# 新しいモデルインスタンスを作成
model = create_model()

# 未訓練モデルの精度を評価
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Untrained model accuracy: {acc:.2%}")

# チェックポイントから重みを読み込み
model.load_weights(checkpoint_path)

# 復元モデルの精度を再評価
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Restored model accuracy: {acc:.2%}")

これで、訓練途中あるいは訓練完了時のモデル性能を簡単に比較できます。

エポックごとにチェックポイントを作成する

もっと細かく、たとえば「5 エポックごとに保存したい」「ファイル名にエポック番号を埋め込みたい」といった場合は、以下のように設定します。

import math

# ファイル名にエポック番号を含める
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt.weights.h5"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

batch_size = 32
n_batches = math.ceil(len(train_images) / batch_size)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    save_freq=5 * n_batches,   # 5 エポックごとに保存
    verbose=1
)

model = create_model()
# エポック 0 の重みを明示的に保存
model.save_weights(checkpoint_path.format(epoch=0))

model.fit(
    train_images, train_labels,
    epochs=50,
    batch_size=batch_size,
    validation_data=(test_images, test_labels),
    callbacks=[cp_callback],
    verbose=0
)

  • save_freq にバッチ数単位を渡すことで、任意のエポック間隔に対応
  • {epoch:04d} で出力ファイル名に 4 桁ゼロ埋めのエポック番号を埋め込み

保存された複数ファイルの中から最新のものを読み込むには、

import glob

# ディレクトリ中の *.weights.h5 を取得し、最終更新日時でソート
files = glob.glob("training_2/*.weights.h5")
latest = max(files, key=os.path.getctime)

model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Latest restored model accuracy: {acc:.2%}")

のようにすれば OK です。

まとめ

  • チェックポイント を作ると、訓練の中断・再開や過学習抑制、最良モデルの切り出しが容易に。
  • ModelCheckpoint コールバックで 重みのみ構造+重み を定期的に保存可能。
  • save_freq やファイル名フォーマットを工夫すると、任意の間隔・名前で複数ファイルを管理できる。
  • 保存ファイルは load_weights で簡単に復元でき、再評価や継続訓練に利用できる。

これらを活用して、再現性と効率を高めた機械学習開発を進めましょう。

参考

Tensorflow公式チュートリアル

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?