7
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

ckptファイルをload_state_dictで読み込んだ話

Posted at

前置き

pytorch_lightningを使って学習したモデルをload_state_dictを使って読み込もうとしたら"Missing key(s) in state_dict..."というエラーが出ました。
今回はこのエラーを解消する手順を説明します。

モデルの保存

モデルの学習と保存について説明します。
まずINTRODUCTION TO PYTORCH LIGHTNINGに書いてあるコードをコピペして実行します。

pl.py
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.models import efficientnet_b0
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64

class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
       return torch.optim.Adam(self.parameters(), lr=0.02)

if __name__ == "__main__":
    # Init our model
    mnist_model = MNISTModel()

    # Init DataLoader from MNIST Dataset
    train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

    # Initialize a trainer
    trainer = Trainer(
        gpus=AVAIL_GPUS,
        max_epochs=3,
        progress_bar_refresh_rate=20,
    )

    # Train the model
    trainer.fit(mnist_model, train_loader)

するとlightning_logsというディレクトリができて、その中にモデルが保存されました。

モデルのロード(失敗例)

以下のコードでモデルを読み込んでみます。

import torch

model = torch.nn.Linear(28 * 28, 10)
checkpoint = torch.load("lightning_logs/version_0/checkpoints/epoch=2-step=2813.ckpt")
model.load_state_dict(checkpoint["state_dict"])

すると次のようなエラーが出ました。

Missing key(s) in state_dict: "weight", "bias". 
	Unexpected key(s) in state_dict: "l1.weight", "l1.bias". 

原因

まずcheckpointの中身を見てみます。
checkpointはdictなのでkeys()メソッドでキーを確認できます。

>>> checkpoint.keys()
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'callbacks', 'optimizer_states', 'lr_schedulers'])

この中で重み、バイアスを保存しているのは"state_dict"です。

>>> checkpoint["state_dict"].keys()
odict_keys(['l1.weight', 'l1.bias'])

"l1.weight", "l1.bias"というキーで重み、バイアスが保存されていることがわかりました。
次にモデルがどのようなキーで重み、バイアスを保存しているか確認してみます。
モデルが持つ層の情報はstate_dict()メソッドで確認できます。

>>> model = torch.nn.Linear(28 * 28, 10)
>>> model.state_dict().keys()
odict_keys(['weight', 'bias'])

モデルは"weight", "bias"というキーで重み、バイアスを保存していました。
このキーの値がずれている事がエラーの原因です。
ちなみに"l1.weight", "l1.bias"のl1はMNISTModel内の

self.l1 = torch.nn.Linear(28 * 28, 10)

から来ています。

解決策(その1)

一つ目の解決策はmodelをtorch.nn.linearではなくMNISTModelで定義する事です。

>>> model = MNISTModel()
>>> model.state_dict().keys()
odict_keys(['l1.weight', 'l1.bias'])

するとモデル側のキーの値が"l1.weight", "l1.bias"になっています。実際に学習済みモデルをロードすると以下のように正常に読み込むことができます。

>>> model.load_state_dict(checkpoint["state_dict"])
<All keys matched successfully>

解決策(その2)

基本的には解決策(その1)で良いと思うんですが、MNISTModelで色んなパッケージを使っていてpredictionのためだけにそれらを全てインストールしたくないという場合もあると思います。(trainとpredictionを別々のPCでやるときとか。)
そのような場合は新しくtorch.nn.Moduleを継承したクラスを定義してコンストラクタでMNISTModelと同じネットワークを定義する方法があります。

>>> class PredModel(torch.nn.Module):
...   def __init__(self):
...     super().__init__()
...     self.l1 = torch.nn.Linear(28 * 28, 10)
... 
>>> model = PredModel()
>>> model.load_state_dict(checkpoint["state_dict"])
<All keys matched successfully>
>>> 

ここでメンバー変数の変数名をMNISTModelと合わせている事に注意してください。
変数名がstate_dictのキー名に影響するので合わせないとロードできません。

注意点

load_state_dictにはstrictという引数があります。
デフォルトはTrueですがFalseにするとキーの値が合うものだけロードして残りはロードしません。

>>> model = torch.nn.Linear(28 * 28, 10)
>>> model.load_state_dict(checkpoint["state_dict"], strict=False)  
_IncompatibleKeys(missing_keys=['weight', 'bias'], unexpected_keys=['l1.weight', 'l1.bias'])

このようにエラーは出ませんがキーの値が違うので何もロードできていません。
個人的にstrict=Falseは使わない方が良いと思います。

7
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?