なにをしたいのか
-
nn.DataParallel(model)
した後、torch.save(model.state_dict(), PATH)
して、RuntimeError: Error(s) in loading state_dict
と言われた場合の処理 - 何度もやってコードに追記する部分のメモ
対処
予防
model.module
として、保存するときに1枚のGPUにおけるstate_dictを拾うようにするのが一番手っ取り早いです。
保存
torch.save(model.module.state_dict(), PATH)
事後
もう保存してしまった場合は、state_dictについているkeyの"module"を消せばひとまず動きます
keyの文字列からmodule.を消す
from collections import OrderedDict
def fix_key(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[7:]
new_state_dict[k] = v
return new_state_dict
ざっくりとした説明
適当にtorchvision.modelsからモデルを持ってきて、複数GPUで並列計算させると
モデルのトレーニング
import torch.nn as nn
import torchvision
model = torchvision.models.alexnet(pretrained=True)
model = nn.DataParallel(model)
# 中略:トレーニング
ここでモデルを保存しようとすると
保存
import torch
torch.save(model.state_dict(), PATH)
で保存ができますが、model.state_dict()は並列化した結果のモデルのstate_dictなので
モデルの再ロード
model = torchvision.models.alexnet()
model.load_state_dict(torch.load(PATH))
とすると
エラー内容
RuntimeError: Error(s) in loading state_dict ...
として、並列化したモデルの重みを1枚のGPU用に用意したモデルに載せようとするためロードができません
もう少し詳しい説明
モデルのパラメータは、OrderedDictで保存されます
OrderedDictを見る
model = torchvision.models.alexnet()
print(model.state_dict().keys())
# odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.10.weight', 'features.10.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])
このkeyはnn.DataParallel
すると、頭に"module"がついてGPUにコピーされます。
DataParallel後のstate_dictのkey
model = nn.DataParallel(model)
print(model.state_dict().keys())
# odict_keys(['module.features.0.weight', 'module.features.0.bias', 'module.features.3.weight', 'module.features.3.bias', 'module.features.6.weight', 'module.features.6.bias', 'module.features.8.weight', 'module.features.8.bias', 'module.features.10.weight', 'module.features.10.bias', 'module.classifier.1.weight', 'module.classifier.1.bias', 'module.classifier.4.weight', 'module.classifier.4.bias', 'module.classifier.6.weight', 'module.classifier.6.bias'])
model.moduleは元通りの値が保存されている
model.module
print(model.module.state_dict().keys())
# odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.10.weight', 'features.10.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])
これは、DataParallelしたmodelが、実は元のmodelをDataParallelのmoduleとして保持する仕様だから。
DataParallel前のモデル
AlexNet(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): ReLU(inplace=True)
(2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): ReLU(inplace=True)
(5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU(inplace=True)
(8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): ReLU(inplace=True)
(10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
(classifier): Sequential(
(0): Dropout(p=0.5, inplace=False)
(1): Linear(in_features=9216, out_features=4096, bias=True)
(2): ReLU(inplace=True)
(3): Dropout(p=0.5, inplace=False)
(4): Linear(in_features=4096, out_features=4096, bias=True)
(5): ReLU(inplace=True)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
DataParallel後のモデル
DataParallel(
(module): AlexNet(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): ReLU(inplace=True)
(2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): ReLU(inplace=True)
(5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU(inplace=True)
(8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): ReLU(inplace=True)
(10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
(classifier): Sequential(
(0): Dropout(p=0.5, inplace=False)
(1): Linear(in_features=9216, out_features=4096, bias=True)
(2): ReLU(inplace=True)
(3): Dropout(p=0.5, inplace=False)
(4): Linear(in_features=4096, out_features=4096, bias=True)
(5): ReLU(inplace=True)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
)