0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Pytorch lightningのEarlyStoppingを高速化

Posted at

モチベーション

  • LLMのファインチューニングは過学習するためにEarlyStoppingをかける事が定番。
  • チェックポイントの保存が必須となるが、標準はストレージへの保存となる。BERT系だと1GB弱もある。
  • 交差検証なら気にならなかったがハイパーパラメータサーチとなるとSSDとはいえI/Oにかかる時間が無視できず、モデルの保存も必須ではない。
  • 100サンプルの少ないデータセットだと学習している時間よりI/Oに時間がかかり、GPUの使用率が低すぎるという事態に…(ほぼ定常状態の発熱と変わらない)
  • なので、計算機サーバのメモリに余裕があるのでインメモリ化して高速化

ライブラリ

  • torch2.0.1+cu118
  • pytorch-lightning 2.0.9.post0

インメモリ化の定義

  • ModelCheckpointの保存先をBytesIOに渡すようにオーバーライドするだけ
import io
class InMemoryModelCheckpoint(ModelCheckpoint):
    def _save_model(self, trainer, filepath):
        """モデルのstate_dictをメモリに保存"""
        binary_stream = io.BytesIO()
        trainer.save_checkpoint(filepath=binary_stream)

通常のModelCheckpointと同様にコンストラクタ引数を設定

checkpoint_callback = InMemoryModelCheckpoint(
    monitor='val_loss', # LightningModuleのself.logで保存しているパラメータ
    mode='min',
    save_top_k=1,
    save_last=False,
)

trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(model)

ロード

通常と同じ

ckpt = torch.load(checkpoint_callback.best_model_path)
model.load_state_dict(ckpt['state_dict'])
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?