LoginSignup
8

More than 1 year has passed since last update.

posted at

nn.DataParallel(model)で並列化して学習させた後うっかりして普通に保存した場合のロード方法

なにをしたいのか

  • nn.DataParallel(model)した後、torch.save(model.state_dict(), PATH)して、RuntimeError: Error(s) in loading state_dictと言われた場合の処理
  • 何度もやってコードに追記する部分のメモ

対処

予防

model.moduleとして、保存するときに1枚のGPUにおけるstate_dictを拾うようにするのが一番手っ取り早いです。

保存
torch.save(model.module.state_dict(), PATH)

事後

もう保存してしまった場合は、state_dictについているkeyの"module"を消せばひとまず動きます

keyの文字列からmodule.を消す
from collections import OrderedDict
def fix_key(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[7:]
        new_state_dict[k] = v
    return new_state_dict

ざっくりとした説明

適当にtorchvision.modelsからモデルを持ってきて、複数GPUで並列計算させると

モデルのトレーニング
import torch.nn as nn
import torchvision
model = torchvision.models.alexnet(pretrained=True)

model = nn.DataParallel(model)

# 中略:トレーニング

ここでモデルを保存しようとすると

保存
import torch
torch.save(model.state_dict(), PATH)

で保存ができますが、model.state_dict()は並列化した結果のモデルのstate_dictなので

モデルの再ロード
model = torchvision.models.alexnet()
model.load_state_dict(torch.load(PATH))

とすると

エラー内容
RuntimeError: Error(s) in loading state_dict ...

として、並列化したモデルの重みを1枚のGPU用に用意したモデルに載せようとするためロードができません

もう少し詳しい説明

モデルのパラメータは、OrderedDictで保存されます

OrderedDictを見る
model = torchvision.models.alexnet()
print(model.state_dict().keys())

# odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.10.weight', 'features.10.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])

このkeyはnn.DataParallelすると、頭に"module"がついてGPUにコピーされます。

DataParallel後のstate_dictのkey
model = nn.DataParallel(model)
print(model.state_dict().keys())

# odict_keys(['module.features.0.weight', 'module.features.0.bias', 'module.features.3.weight', 'module.features.3.bias', 'module.features.6.weight', 'module.features.6.bias', 'module.features.8.weight', 'module.features.8.bias', 'module.features.10.weight', 'module.features.10.bias', 'module.classifier.1.weight', 'module.classifier.1.bias', 'module.classifier.4.weight', 'module.classifier.4.bias', 'module.classifier.6.weight', 'module.classifier.6.bias'])

model.moduleは元通りの値が保存されている

model.module
print(model.module.state_dict().keys())

# odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.10.weight', 'features.10.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])

これは、DataParallelしたmodelが、実は元のmodelをDataParallelのmoduleとして保持する仕様だから。

DataParallel前のモデル
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

DataParallel後のモデル
DataParallel(
  (module): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
    (classifier): Sequential(
      (0): Dropout(p=0.5, inplace=False)
      (1): Linear(in_features=9216, out_features=4096, bias=True)
      (2): ReLU(inplace=True)
      (3): Dropout(p=0.5, inplace=False)
      (4): Linear(in_features=4096, out_features=4096, bias=True)
      (5): ReLU(inplace=True)
      (6): Linear(in_features=4096, out_features=1000, bias=True)
    )
  )
)

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
8