2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Pytorchの.ckptファイルの引っ越し.モデル構造を変更したときの重みの移行

Last updated at Posted at 2024-06-27

モデルをリファクタリングして構造を変えたあと,前のckptから重みをロードしたいときのためのメモ

そもそもstate_dictの構造とは?

(既にご存知の方は読み飛ばしてください.)

state_dictは,モデルのどのパラメータにどんな重みを渡せばいいかという情報を辞書型で持っており,
ckptファイルをロードすることで取り出すことが出来る.

以下にstate_dictの簡単な例を示す.

SampleModel

class SampleModel(nn.Module):
    def __init__():
        self.embeddings = MyEmbeddingModule()
        self.linear = nn.Linear(...)

class MyEmbeddingModule(nn.Module):
    def __init__():
        self.proj = nn.Conv2D(...)

このようなモデル構造の場合に得られるstate_dictは下記のようになる.

state_dict = {
    "embeddings.proj.weight": <torch.Tensor> # パラメータ名: 重み
    "embeddings.proj.bias": <torch.Tensor>
    "linear": <torch.Tensor>
}

つまり,モデル構造を変更したけれど以前のckptを使いたいときには,適切にstate_dictのkeyを変化させる必要がある.

実装

全体像のイメージは大体こんな感じです.
古いckptから新しいckptを作成してロードします.

def create_converted_ckpt_file(checkpoint_path: str) -> str:
    mapper = {
        "変更したいパラメータ名": "変更後のパラメータ名"
        ...
    }
    ...
    state_dict = checkpoint_pathからstate_dictを取り出す
    new_state_dict = _map_weights(state_dict, mapper) # keyを書き換える
    new_ckpt_path = new_state_dictを元に新しいckptを作成
    return new_ckpt_path
    
old_ckpt_path = "./hoge/hoge.ckpt" # 古いckptファイル
new_ckpt_path = create_converted_ckpt_file(checkpoint_file_path)
model.load_from_checkpoint(new_ckpt_path)

state_dictの書き換えを_map_weights, create_converted_ckpt_file関数で行っており,
次にその内部について説明します.

state_dictの書き換え

_map_weights 関数

変更前state_dictとパラメータ名のmapperを受け取り,変更後のstate_dictを返す.
後述のcreate_converted_ckpt_file 関数内で利用.

  • 入力:
    • state_dict : 変更前のckptからロードしたstate_dict
      • (key: value) = (変更前パラメータ名 : 重み)
    • mapper : パラメータ変更のdict
      • (key: value) = (変更前パラメータ名 : 変更後パラメータ名)
  • 出力:
    • new_state_dict: パラメータ名を変更したstate_dict
      • (key: value) = (変更後パラメータ名 : 重み)
# state_dictを更新する関数
def _map_weights(state_dict, mapper: dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # kにmapperのkeyが含まれている場合、その部分をmapperのvalueに置換する
        for old, new in mapper.items():
            if old in k and new not in k:
                k = k.replace(old, mapper[old])
        new_state_dict[k] = v

    return new_state_dict

.ckptの更新

create_converted_ckpt_file 関数

変更前のckptパスを受け取って,移行したモデル構造に合ったckptファイルを作成,そのパスを返す.
(変化したパラメータ名の情報はdict(コードではmapper変数)で保持しておく.)

  • Input
    • ~.ckptファイルのパス
    • ※ .ckptファイルはstate_dictをキーとしてsate_dictを持つと仮定
  • Output
    • 更新後の重みを~.converted.ckptに保存してパスを返す
# 更新したckpt_pathを返す関数
def create_converted_ckpt_file(checkpoint_path: str) -> str:
    mapper = {
        "変更したいパラメータの部分str": "変更後の部分str"
        "model.linear":"model.decoder.linear",
        "model.transformer":"model.decoder.transformer"
    }
    dest_ckpt_path = checkpoint_path.replace(".ckpt", ".converted.ckpt")
    if Path(dest_ckpt_path).exists():
        return dest_ckpt_path
    model_ckpt = torch.load(checkpoint_path)
    model_dict = model_ckpt['state_dict']
    new_model_dict = _map_weights(model_dict, mapper=mapper)
    model_ckpt['state_dict'] = new_model_dict
    torch.save(model_ckpt, dest_ckpt_path)
    return dest_ckpt_path

.ckptのロード

checkpoint_fileを上記の関数に渡して,更新したckptのパスを取得.そして,これをロードする.

checkpoint_file_path = "./hoge/hoge.ckpt"
ckpt_path = create_converted_ckpt_file(checkpoint_file_path)

model = MyModel.load_state_dict(torch.load(ckpt_path)['state_dict'])

# pytorch-lighitningの場合
model = MyLighitningModel.load_from_checkpoint(ckpt_path)
2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?