モチベーション
- LLMのファインチューニングは過学習するためにEarlyStoppingをかける事が定番。
- チェックポイントの保存が必須となるが、標準はストレージへの保存となる。BERT系だと1GB弱もある。
- 交差検証なら気にならなかったがハイパーパラメータサーチとなるとSSDとはいえI/Oにかかる時間が無視できず、モデルの保存も必須ではない。
- 100サンプルの少ないデータセットだと学習している時間よりI/Oに時間がかかり、GPUの使用率が低すぎるという事態に…(ほぼ定常状態の発熱と変わらない)
- なので、計算機サーバのメモリに余裕があるのでインメモリ化して高速化
ライブラリ
- torch2.0.1+cu118
- pytorch-lightning 2.0.9.post0
インメモリ化の定義
- ModelCheckpointの保存先をBytesIOに渡すようにオーバーライドするだけ
import io
class InMemoryModelCheckpoint(ModelCheckpoint):
def _save_model(self, trainer, filepath):
"""モデルのstate_dictをメモリに保存"""
binary_stream = io.BytesIO()
trainer.save_checkpoint(filepath=binary_stream)
通常のModelCheckpointと同様にコンストラクタ引数を設定
checkpoint_callback = InMemoryModelCheckpoint(
monitor='val_loss', # LightningModuleのself.logで保存しているパラメータ
mode='min',
save_top_k=1,
save_last=False,
)
trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(model)
ロード
通常と同じ
ckpt = torch.load(checkpoint_callback.best_model_path)
model.load_state_dict(ckpt['state_dict'])