lighitningでdeepspeedを用いる方法は下記を参考にしてください.
Deepspeedのcheckpointのロード
deepspeedのチェックポイントはディレクトリになっており, 単純にload_from_checkpoint
ではロードできません.
なので,convert_zero_checkpoint_to_fp32_state_dict
関数でptファイルへと変換してからロードしましょう.
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
ckpt_dir_path = "/hoge/epochX_stepY.ckpt"
dist_path = ckpt_dir_path + "/state_dict.pt"
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_path, dist_path)
MyLighiningModel.load_from_checkpoint(dist_path)