Kerasで学習途中の最適な重みを保存しておくには、ModelCheckpointコールバックを使うのが定石です。
Tensorflow チュートリアル - 訓練中にチェックポイントを保存する
しかし、保存先がファイルしか選択できません。ファイルをやたり読み書きするのは気が引けますし、なによりも遅いです。これをメモリに保存しておいて取り出す方法がありそうで無いです。いろいろ探したところ、英文の質問応答サイトに方法が書かれていたので、実際に試してブラッシュアップした内容を共有します。
以下が作成したメモリに重みを書き出すコードです。ModelCheckpointは使っていませんが、同等の処理を記述できメモリ保存できるコールバックを新たに作成しています。ディスクに書き出すのと比較して、体感で5倍程度の学習速度になりました(もちろんエポックの計算量に依存します)。
# メモリに重みを保存するコールバック
class SaveBestWeightsInMemory(tf.keras.callbacks.Callback):
def __init__(self):
super().__init__()
self.best = np.Inf
self.best_epoch = 0
def on_epoch_end(self, epoch, logs=None):
if np.less(val_loss := logs.get('val_loss'), self.best):
print(f'{epoch=}: val_loss improved from {self.best:.4f} to {val_loss:.4f}'
', saving model to memory.')
self.best = val_loss
self.best_epoch = epoch
self.best_weights = self.model.get_weights()
# 学習コード本体(モデル作成、コンパイル済)
epochs = 3000
callback = [
sb_cb := SaveBestWeightsInMemory(), # 今回作成のコールバックの呼び出し
LambdaCallback(on_epoch_end=lambda epoch, logs: (
print(f'{epoch=}: ' + ', '.join([f'{k}:{v:.4f}' for k, v in logs.items()]))
if epoch % 10 == 0 else None)), # 10エポックに1度経過表示をするコールバック
EarlyStopping(monitor='val_loss', patience=500), # 通常はアーリーストッピングと組み合わせ
]
result = model.fit(x_train, y_train, epochs=epochs, verbose=0,
callbacks=callback, validation_data=(x_val, y_val))
model.set_weights(sb_cb.best_weights) # メモリ保存した最適な重みを読み出す
なお、上記コードには、Python3.8以上でないと動かない文法が入っています。