LoginSignup
35
24

More than 3 years have passed since last update.

Pytorchのnn.DataParallelを使ったモデルを保存するとloadするときにエラーになる問題

Posted at

PytorchはMultiGPUで学習・推論するときの便利な機能として torch.nn.DataParallelというモジュールを準備してくれている。
すごく楽に使えるのですが、一方でこれで学習したモデルを保存・利用する際にちょっと気をつけないと行けない点があるので忘れないうちにメモしておく。

なぜ起こるのか?

ここで、 self.module = module が入っているので、Model構成のprefixが元のModel Classと変わるため。
ミスると下記のようなエラーが出る。

KeyError: 'unexpected key "module.xx.xx.weight" in state_dict'

解決策1. 保存する前に気をつける

こんな感じで nn.DataParallel() のクラスモジュールから元のファイルをもって来て保存する。

torch.save(model.module.state_dict(), output_model_path)

解決策2. 保存したものを書き換える

もうすでに保存してしまって悲しい思いをしてしまってる方は、諦めて無理やりparameterのkeyを書き換える。

*下記のリンクを参考
https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686

def fix_model_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k
        if name.startswith('module.'):
            name = name[7:]  # remove 'module.' of dataparallel
        new_state_dict[name] = v
    return new_state_dict

# load it
state_dict = torch.load('my_model.pth')
model = MyModel()
model.load_state_dict(fix_model_state_dict(state_dict))
35
24
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
35
24