0
0
この記事誰得? 私しか得しないニッチな技術で記事投稿!
Qiita Engineer Festa20242024年7月17日まで開催中!

Pytorchの.ckptファイルの引っ越し.モデル構造を変更したときの重みの移行

Last updated at Posted at 2024-06-27

モデルをリファクタリングして構造を変えたあと,前のckptから重みをロードしたいときのためのメモ

.ckptの更新

入力: ~.ckptファイルのパス
出力:更新後の重みを~.converted.ckptに保存してパスを返す

mapperにて,変更前のパラメータと変更後のパラメータをdictで持っておく.
※ .ckptファイルはstate_dictをキーとしてsate_dictを持つと仮定

# 更新したckpt_pathを返す関数
def create_converted_ckpt_file(checkpoint_path: str) -> str:
    mapper = {
        "変更したいパラメータの部分str": "変更後の部分str"
        "model.linear":"model.decoder.linear",
        "model.transformer":"model.decoder.transformer"
    }
    dest_ckpt_path = checkpoint_path.replace(".ckpt", ".converted.ckpt")
    if Path(dest_ckpt_path).exists():
        return dest_ckpt_path
    model_ckpt = torch.load(checkpoint_path)
    model_dict = model_ckpt['state_dict']
    new_model_dict = _map_weights(model_dict, mapper=mapper)
    model_ckpt['state_dict'] = new_model_dict
    torch.save(model_ckpt, dest_ckpt_path)
    return dest_ckpt_path
# state_dictを更新する関数
def _map_weights(state_dict, mapper: dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # kにmapperのkeyが含まれている場合、その部分をmapperのvalueに置換する
        for old, new in mapper.items():
            if old in k and new not in k:
                k = k.replace(old, mapper[old])
        new_state_dict[k] = v

    return new_state_dict

.ckptのロード

まず,checkpoint_fileを上記の関数に渡して,更新したckptのパスを取得.そして,これをロードする.

checkpoint_file_path = "./hoge/hoge.ckpt"
ckpt_path = create_converted_ckpt_file(checkpoint_file_path)

model = MyModel.load_state_dict(torch.load(ckpt_path)['state_dict'])

# pytorch-lighitningの場合
model = MyLighitningModel.load_from_checkpoint(ckpt_path)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0