はじめに
PyTorch Lightningで保存されたモデルを使って予測をしようとした時に、Missing key(s) in state_dict:...
のエラーが出てハマった話
コードとエラー内容
コード
model = GraphConvModel.load_from_checkpoint(args.model)
エラー
model = GraphConvModel.load_from_checkpoint(args.model)
File "C:\Users\kimisyo\.conda\envs\OpenChem\lib\site-packages\pytorch_lightning\core\saving.py", line 158, in load_from_checkpoint
model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
File "C:\Users\kimisyo\.conda\envs\OpenChem\lib\sitepackages\pytorch_lightning\core\saving.py", line 204, in _load_model_state
model.load_state_dict(checkpoint['state_dict'], strict=strict)
File "C:\Users\kimisyo\.conda\envs\OpenChem\lib\site-packages\torch\nn\modules\module.py", line 1052, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for GraphConvModel:
Missing key(s) in state_dict: "graph_conv_layer_list.1.bias","graph_conv_layer_list.1.self_activation.weight", "graph_conv_layer_list.1.degree_liner_layer_list.0.weight","graph_conv_layer_list.1.degree_liner_layer_list.1.weight","graph_conv_layer_list.1.degree_liner_layer_list.2.weight", "graph_conv_layer_list.1.degree_l
(以下略)
想定される原因
モデル(pl.LightningModuleを継承して作成したもの) 学習時に、隠れ層の数やサイズをコンストクタの引数で与えていたが、load_from_checkpoint
では、デフォルト引数でインスタンス化したモデルに対して重みがロードされることが原因と考えられる。
対策
当初は、モデルを構築時と同じ条件でインスタンス化し、そのインスタンス(例えばmodel)に対して以下のようにload_model_checkpointを呼び出せばよいかと思ったが、解消されず。
model.GraphConvModel.load_from_checkpoint(args.model)
色々調べた結果、以下のようにload_from_checkpointの引数に、Modelのインスタンスに与えたい引数を指定したところ解決。
model = GraphConvModel.load_from_checkpoint(args.model,
atom_features_size=62,
conv_layer_sizes=[20,20,20],
fingerprints_size=50,
mlp_layer_sizes=[100, 1]
)
今回はPyTorch Lightningに特化して書いたが、PyTorch Lightningを使わないケースでも、load_from_checkpointを使ってモデルをロードする場合、注意が必要と思われる。
参考