モデルをリファクタリングして構造を変えたあと,前のckptから重みをロードしたいときのためのメモ
.ckptの更新
入力: ~.ckpt
ファイルのパス
出力:更新後の重みを~.converted.ckpt
に保存してパスを返す
mapperにて,変更前のパラメータと変更後のパラメータをdictで持っておく.
※ .ckptファイルはstate_dict
をキーとしてsate_dictを持つと仮定
# 更新したckpt_pathを返す関数
def create_converted_ckpt_file(checkpoint_path: str) -> str:
mapper = {
"変更したいパラメータの部分str": "変更後の部分str"
"model.linear":"model.decoder.linear",
"model.transformer":"model.decoder.transformer"
}
dest_ckpt_path = checkpoint_path.replace(".ckpt", ".converted.ckpt")
if Path(dest_ckpt_path).exists():
return dest_ckpt_path
model_ckpt = torch.load(checkpoint_path)
model_dict = model_ckpt['state_dict']
new_model_dict = _map_weights(model_dict, mapper=mapper)
model_ckpt['state_dict'] = new_model_dict
torch.save(model_ckpt, dest_ckpt_path)
return dest_ckpt_path
# state_dictを更新する関数
def _map_weights(state_dict, mapper: dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
# kにmapperのkeyが含まれている場合、その部分をmapperのvalueに置換する
for old, new in mapper.items():
if old in k and new not in k:
k = k.replace(old, mapper[old])
new_state_dict[k] = v
return new_state_dict
.ckptのロード
まず,checkpoint_fileを上記の関数に渡して,更新したckptのパスを取得.そして,これをロードする.
checkpoint_file_path = "./hoge/hoge.ckpt"
ckpt_path = create_converted_ckpt_file(checkpoint_file_path)
model = MyModel.load_state_dict(torch.load(ckpt_path)['state_dict'])
# pytorch-lighitningの場合
model = MyLighitningModel.load_from_checkpoint(ckpt_path)