LoginSignup
4
5

More than 3 years have passed since last update.

PyTorch Lightningで予測時に"Missing key(s) in state_dict: ..." のエラーが出た時の対応メモ

Last updated at Posted at 2021-02-23

はじめに

PyTorch Lightningで保存されたモデルを使って予測をしようとした時に、Missing key(s) in state_dict:...のエラーが出てハマった話

コードとエラー内容

コード

model = GraphConvModel.load_from_checkpoint(args.model)

エラー

    model = GraphConvModel.load_from_checkpoint(args.model)
  File "C:\Users\kimisyo\.conda\envs\OpenChem\lib\site-packages\pytorch_lightning\core\saving.py", line 158, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "C:\Users\kimisyo\.conda\envs\OpenChem\lib\sitepackages\pytorch_lightning\core\saving.py", line 204, in _load_model_state
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
  File "C:\Users\kimisyo\.conda\envs\OpenChem\lib\site-packages\torch\nn\modules\module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for GraphConvModel:
        Missing key(s) in state_dict: "graph_conv_layer_list.1.bias","graph_conv_layer_list.1.self_activation.weight", "graph_conv_layer_list.1.degree_liner_layer_list.0.weight","graph_conv_layer_list.1.degree_liner_layer_list.1.weight","graph_conv_layer_list.1.degree_liner_layer_list.2.weight", "graph_conv_layer_list.1.degree_l

(以下略)

想定される原因

モデル(pl.LightningModuleを継承して作成したもの) 学習時に、隠れ層の数やサイズをコンストクタの引数で与えていたが、load_from_checkpointでは、デフォルト引数でインスタンス化したモデルに対して重みがロードされることが原因と考えられる。

対策

当初は、モデルを構築時と同じ条件でインスタンス化し、そのインスタンス(例えばmodel)に対して以下のようにload_model_checkpointを呼び出せばよいかと思ったが、解消されず。

model.GraphConvModel.load_from_checkpoint(args.model)

色々調べた結果、以下のようにload_from_checkpointの引数に、Modelのインスタンスに与えたい引数を指定したところ解決。

model = GraphConvModel.load_from_checkpoint(args.model,
    atom_features_size=62,
    conv_layer_sizes=[20,20,20],
    fingerprints_size=50,
    mlp_layer_sizes=[100, 1]
)   

今回はPyTorch Lightningに特化して書いたが、PyTorch Lightningを使わないケースでも、load_from_checkpointを使ってモデルをロードする場合、注意が必要と思われる。

参考

https://pytorch-lightning.readthedocs.io/en/0.8.5/weights_loading.html

4
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
5