モデルをリファクタリングして構造を変えたあと,前のckptから重みをロードしたいときのためのメモ
そもそもstate_dictの構造とは?
(既にご存知の方は読み飛ばしてください.)
state_dictは,モデルのどのパラメータにどんな重みを渡せばいいかという情報を辞書型で持っており,
ckptファイルをロードすることで取り出すことが出来る.
以下にstate_dictの簡単な例を示す.
SampleModel
class SampleModel(nn.Module):
def __init__():
self.embeddings = MyEmbeddingModule()
self.linear = nn.Linear(...)
class MyEmbeddingModule(nn.Module):
def __init__():
self.proj = nn.Conv2D(...)
このようなモデル構造の場合に得られるstate_dictは下記のようになる.
state_dict = {
"embeddings.proj.weight": <torch.Tensor> # パラメータ名: 重み
"embeddings.proj.bias": <torch.Tensor>
"linear": <torch.Tensor>
}
つまり,モデル構造を変更したけれど以前のckptを使いたいときには,適切にstate_dictのkeyを変化させる必要がある.
実装
全体像のイメージは大体こんな感じです.
古いckptから新しいckptを作成してロードします.
def create_converted_ckpt_file(checkpoint_path: str) -> str:
mapper = {
"変更したいパラメータ名": "変更後のパラメータ名"
...
}
...
state_dict = checkpoint_pathからstate_dictを取り出す
new_state_dict = _map_weights(state_dict, mapper) # keyを書き換える
new_ckpt_path = new_state_dictを元に新しいckptを作成
return new_ckpt_path
old_ckpt_path = "./hoge/hoge.ckpt" # 古いckptファイル
new_ckpt_path = create_converted_ckpt_file(checkpoint_file_path)
model.load_from_checkpoint(new_ckpt_path)
state_dictの書き換えを_map_weights
, create_converted_ckpt_file
関数で行っており,
次にその内部について説明します.
state_dictの書き換え
_map_weights
関数
変更前state_dictとパラメータ名のmapperを受け取り,変更後のstate_dictを返す.
後述のcreate_converted_ckpt_file
関数内で利用.
- 入力:
- state_dict : 変更前のckptからロードしたstate_dict
- (key: value) = (変更前パラメータ名 : 重み)
- mapper : パラメータ変更のdict
- (key: value) = (変更前パラメータ名 : 変更後パラメータ名)
- state_dict : 変更前のckptからロードしたstate_dict
- 出力:
- new_state_dict: パラメータ名を変更したstate_dict
- (key: value) = (変更後パラメータ名 : 重み)
- new_state_dict: パラメータ名を変更したstate_dict
# state_dictを更新する関数
def _map_weights(state_dict, mapper: dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
# kにmapperのkeyが含まれている場合、その部分をmapperのvalueに置換する
for old, new in mapper.items():
if old in k and new not in k:
k = k.replace(old, mapper[old])
new_state_dict[k] = v
return new_state_dict
.ckptの更新
create_converted_ckpt_file
関数
変更前のckptパスを受け取って,移行したモデル構造に合ったckptファイルを作成,そのパスを返す.
(変化したパラメータ名の情報はdict(コードではmapper変数)で保持しておく.)
- Input
-
~.ckpt
ファイルのパス - ※ .ckptファイルは
state_dict
をキーとしてsate_dictを持つと仮定
-
- Output
- 更新後の重みを
~.converted.ckpt
に保存してパスを返す
- 更新後の重みを
# 更新したckpt_pathを返す関数
def create_converted_ckpt_file(checkpoint_path: str) -> str:
mapper = {
"変更したいパラメータの部分str": "変更後の部分str"
"model.linear":"model.decoder.linear",
"model.transformer":"model.decoder.transformer"
}
dest_ckpt_path = checkpoint_path.replace(".ckpt", ".converted.ckpt")
if Path(dest_ckpt_path).exists():
return dest_ckpt_path
model_ckpt = torch.load(checkpoint_path)
model_dict = model_ckpt['state_dict']
new_model_dict = _map_weights(model_dict, mapper=mapper)
model_ckpt['state_dict'] = new_model_dict
torch.save(model_ckpt, dest_ckpt_path)
return dest_ckpt_path
.ckptのロード
checkpoint_fileを上記の関数に渡して,更新したckptのパスを取得.そして,これをロードする.
checkpoint_file_path = "./hoge/hoge.ckpt"
ckpt_path = create_converted_ckpt_file(checkpoint_file_path)
model = MyModel.load_state_dict(torch.load(ckpt_path)['state_dict'])
# pytorch-lighitningの場合
model = MyLighitningModel.load_from_checkpoint(ckpt_path)