前置き
pytorch_lightningを使って学習したモデルをload_state_dictを使って読み込もうとしたら"Missing key(s) in state_dict..."というエラーが出ました。
今回はこのエラーを解消する手順を説明します。
モデルの保存
モデルの学習と保存について説明します。
まずINTRODUCTION TO PYTORCH LIGHTNINGに書いてあるコードをコピペして実行します。
import os
import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.models import efficientnet_b0
from torchvision.datasets import MNIST
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64
class MNISTModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, 10)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_nb):
x, y = batch
loss = F.cross_entropy(self(x), y)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
if __name__ == "__main__":
# Init our model
mnist_model = MNISTModel()
# Init DataLoader from MNIST Dataset
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)
# Initialize a trainer
trainer = Trainer(
gpus=AVAIL_GPUS,
max_epochs=3,
progress_bar_refresh_rate=20,
)
# Train the model
trainer.fit(mnist_model, train_loader)
するとlightning_logsというディレクトリができて、その中にモデルが保存されました。
モデルのロード(失敗例)
以下のコードでモデルを読み込んでみます。
import torch
model = torch.nn.Linear(28 * 28, 10)
checkpoint = torch.load("lightning_logs/version_0/checkpoints/epoch=2-step=2813.ckpt")
model.load_state_dict(checkpoint["state_dict"])
すると次のようなエラーが出ました。
Missing key(s) in state_dict: "weight", "bias".
Unexpected key(s) in state_dict: "l1.weight", "l1.bias".
原因
まずcheckpointの中身を見てみます。
checkpointはdictなのでkeys()メソッドでキーを確認できます。
>>> checkpoint.keys()
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'callbacks', 'optimizer_states', 'lr_schedulers'])
この中で重み、バイアスを保存しているのは"state_dict"です。
>>> checkpoint["state_dict"].keys()
odict_keys(['l1.weight', 'l1.bias'])
"l1.weight", "l1.bias"というキーで重み、バイアスが保存されていることがわかりました。
次にモデルがどのようなキーで重み、バイアスを保存しているか確認してみます。
モデルが持つ層の情報はstate_dict()メソッドで確認できます。
>>> model = torch.nn.Linear(28 * 28, 10)
>>> model.state_dict().keys()
odict_keys(['weight', 'bias'])
モデルは"weight", "bias"というキーで重み、バイアスを保存していました。
このキーの値がずれている事がエラーの原因です。
ちなみに"l1.weight", "l1.bias"のl1はMNISTModel内の
self.l1 = torch.nn.Linear(28 * 28, 10)
から来ています。
解決策(その1)
一つ目の解決策はmodelをtorch.nn.linearではなくMNISTModelで定義する事です。
>>> model = MNISTModel()
>>> model.state_dict().keys()
odict_keys(['l1.weight', 'l1.bias'])
するとモデル側のキーの値が"l1.weight", "l1.bias"になっています。実際に学習済みモデルをロードすると以下のように正常に読み込むことができます。
>>> model.load_state_dict(checkpoint["state_dict"])
<All keys matched successfully>
解決策(その2)
基本的には解決策(その1)で良いと思うんですが、MNISTModelで色んなパッケージを使っていてpredictionのためだけにそれらを全てインストールしたくないという場合もあると思います。(trainとpredictionを別々のPCでやるときとか。)
そのような場合は新しくtorch.nn.Moduleを継承したクラスを定義してコンストラクタでMNISTModelと同じネットワークを定義する方法があります。
>>> class PredModel(torch.nn.Module):
... def __init__(self):
... super().__init__()
... self.l1 = torch.nn.Linear(28 * 28, 10)
...
>>> model = PredModel()
>>> model.load_state_dict(checkpoint["state_dict"])
<All keys matched successfully>
>>>
ここでメンバー変数の変数名をMNISTModelと合わせている事に注意してください。
変数名がstate_dictのキー名に影響するので合わせないとロードできません。
注意点
load_state_dictにはstrictという引数があります。
デフォルトはTrueですがFalseにするとキーの値が合うものだけロードして残りはロードしません。
>>> model = torch.nn.Linear(28 * 28, 10)
>>> model.load_state_dict(checkpoint["state_dict"], strict=False)
_IncompatibleKeys(missing_keys=['weight', 'bias'], unexpected_keys=['l1.weight', 'l1.bias'])
このようにエラーは出ませんがキーの値が違うので何もロードできていません。
個人的にstrict=Falseは使わない方が良いと思います。