LoginSignup
23
9

More than 3 years have passed since last update.

nn.DataParallel(model)で並列化して学習させた後うっかりして普通に保存した場合のロード方法

Posted at

なにをしたいのか

  • nn.DataParallel(model)した後、torch.save(model.state_dict(), PATH)して、RuntimeError: Error(s) in loading state_dictと言われた場合の処理
  • 何度もやってコードに追記する部分のメモ

対処

予防

model.moduleとして、保存するときに1枚のGPUにおけるstate_dictを拾うようにするのが一番手っ取り早いです。

保存
torch.save(model.module.state_dict(), PATH)

事後

もう保存してしまった場合は、state_dictについているkeyの"module"を消せばひとまず動きます

keyの文字列からmodule.を消す
from collections import OrderedDict
def fix_key(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[7:]
        new_state_dict[k] = v
    return new_state_dict

ざっくりとした説明

適当にtorchvision.modelsからモデルを持ってきて、複数GPUで並列計算させると

モデルのトレーニング
import torch.nn as nn
import torchvision
model = torchvision.models.alexnet(pretrained=True)

model = nn.DataParallel(model)

# 中略:トレーニング

ここでモデルを保存しようとすると

保存
import torch
torch.save(model.state_dict(), PATH)

で保存ができますが、model.state_dict()は並列化した結果のモデルのstate_dictなので

モデルの再ロード
model = torchvision.models.alexnet()
model.load_state_dict(torch.load(PATH))

とすると

エラー内容
RuntimeError: Error(s) in loading state_dict ...

として、並列化したモデルの重みを1枚のGPU用に用意したモデルに載せようとするためロードができません

もう少し詳しい説明

モデルのパラメータは、OrderedDictで保存されます

OrderedDictを見る
model = torchvision.models.alexnet()
print(model.state_dict().keys())

# odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.10.weight', 'features.10.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])

このkeyはnn.DataParallelすると、頭に"module"がついてGPUにコピーされます。

DataParallel後のstate_dictのkey
model = nn.DataParallel(model)
print(model.state_dict().keys())

# odict_keys(['module.features.0.weight', 'module.features.0.bias', 'module.features.3.weight', 'module.features.3.bias', 'module.features.6.weight', 'module.features.6.bias', 'module.features.8.weight', 'module.features.8.bias', 'module.features.10.weight', 'module.features.10.bias', 'module.classifier.1.weight', 'module.classifier.1.bias', 'module.classifier.4.weight', 'module.classifier.4.bias', 'module.classifier.6.weight', 'module.classifier.6.bias'])

model.moduleは元通りの値が保存されている

model.module
print(model.module.state_dict().keys())

# odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.10.weight', 'features.10.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])

これは、DataParallelしたmodelが、実は元のmodelをDataParallelのmoduleとして保持する仕様だから。

DataParallel前のモデル
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

DataParallel後のモデル
DataParallel(
  (module): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
    (classifier): Sequential(
      (0): Dropout(p=0.5, inplace=False)
      (1): Linear(in_features=9216, out_features=4096, bias=True)
      (2): ReLU(inplace=True)
      (3): Dropout(p=0.5, inplace=False)
      (4): Linear(in_features=4096, out_features=4096, bias=True)
      (5): ReLU(inplace=True)
      (6): Linear(in_features=4096, out_features=1000, bias=True)
    )
  )
)

23
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
23
9