179
119

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

pytorchで書いたモデルを保存するときの落とし穴

Posted at

Deep Learningのフレームワークとして最近伸びてきているpytorchを触ってみたら、モデルの保存で思いがけない落とし穴があったのでメモ。

概要

torch.save(the_model, PATH)

この方法で保存したモデルはGPUの情報を含んでいて、ロードのときに自動的に保存時のGPUに乗る。

公式ドキュメント

pytorchのモデルの保存について公式でまとめてあるサイト。
http://pytorch.org/docs/master/notes/serialization.html

モデルの保存方法には以下の2つの方法がある。

方法1

保存時
torch.save(the_model.state_dict(), PATH)
ロード時
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

方法2

保存時
torch.save(the_model, PATH)
ロード時
the_model = torch.load(PATH)

問題なのは2つめの方法

実際にやってみる

まず適当にモデルを組む。

model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class DummyModel(nn.Module):
    def __init__(self):
        super(DummyModel, self).__init__()

        self.l1 = nn.Linear(100, 10)
        self.l2 = nn.Linear(10, 10)

    def forward(self, x):
        h1 = F.tanh(self.l1(x))
        return self.l2(h1)

モデルを作って、GPUに送って学習はせずに保存する。

create_model.py
import torch
import model

def main():
    # モデルの作成
    m = model.DummyModel()

    # GPUに転送
    m = m.cuda()

    # 保存
    torch.save(m, 'model')


if __name__ == '__main__':
    main()

ロードしてみる。

load_model.py
import torch
import model

def main():
    m = torch.load('model')

    for param in m.parameters():
        # 型を調べるとCPUかGPUかわかる。
        # CPU: torch.FloatTensor
        # GPU: torch.cuda.FloatTensor
        print(type(param.data))

if __name__ == '__main__':
    main()

結果:

% python load_model.py 
<class 'torch.cuda.FloatTensor'>
<class 'torch.cuda.FloatTensor'>
<class 'torch.cuda.FloatTensor'>
<class 'torch.cuda.FloatTensor'>

がっつりGPUに乗ってる。(CPUなら'torch.FloatTensor'なはず)

何が問題なのか。

  1. 学習時に使っていたGPUが他の人に使われているとロードすらできなくなる。
  2. 学習の途中でGPUを変えるのが非常にめんどくさい。

結論

公式ドキュメントでもおすすめされている1つ目の方法を使って、モデルは保存しましょう。

保存時
torch.save(the_model.state_dict(), PATH)
ロード時
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
179
119
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
179
119

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?