PytorchはMultiGPUで学習・推論するときの便利な機能として torch.nn.DataParallel
というモジュールを準備してくれている。
すごく楽に使えるのですが、一方でこれで学習したモデルを保存・利用する際にちょっと気をつけないと行けない点があるので忘れないうちにメモしておく。
なぜ起こるのか?
ここで、 self.module = module
が入っているので、Model構成のprefixが元のModel Classと変わるため。
ミスると下記のようなエラーが出る。
KeyError: 'unexpected key "module.xx.xx.weight" in state_dict'
解決策1. 保存する前に気をつける
こんな感じで nn.DataParallel()
のクラスモジュールから元のファイルをもって来て保存する。
torch.save(model.module.state_dict(), output_model_path)
解決策2. 保存したものを書き換える
もうすでに保存してしまって悲しい思いをしてしまってる方は、諦めて無理やりparameterのkeyを書き換える。
def fix_model_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k
if name.startswith('module.'):
name = name[7:] # remove 'module.' of dataparallel
new_state_dict[name] = v
return new_state_dict
# load it
state_dict = torch.load('my_model.pth')
model = MyModel()
model.load_state_dict(fix_model_state_dict(state_dict))