LoginSignup
45
34

More than 3 years have passed since last update.

【Python】Pytorchのモデル保存・読み込みをもっと簡単にしたい!

Last updated at Posted at 2019-06-06

目次

はじめに

Pytorchモデルの保存・読み込みは,以下のような方法で行うことができます。

保存
torch.save(the_model.state_dict(), PATH)
読み込み
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

ただこの方法だと保存ファイルの他に,モデルのクラスやその引数なども覚えておく必要があり,それがちょっと不便でした。

ということで,これらを一緒に保存して保存ファイルのパスだけで読み込めるような関数を書きました。

追記 (2019/06/07)

改めて調べてみたら,もっと簡単な方法を見つけてしまいました。
cloudpickleモジュールを使用します。

import cloudpickle
保存
with open('model.pkl', 'wb') as f:
    cloudpickle.dump(model, f)
読み込み
with open('model.pkl', 'rb') as f:
    model = cloudpickle.load(f)

参考
https://qiita.com/kojisuganuma/items/e9b29e8e5ef5f5f289b2

保存用関数

保存
import inspect
import pickle
from copy import deepcopy

def save_model(model, filename, args=[], kwargs={}):
    ''' モデル・重みの保存 '''
    model_cpu = deepcopy(model).cpu()
    state = {'module_path': inspect.getmodule(model, _filename=True).__file__,
             'class_name': model.__class__.__name__,
             'state_dict': model_cpu.state_dict(),
             'args': args,
             'kwargs': kwargs}
    with open(filename, 'wb') as f:
        pickle.dump(state, f)
  • module_path : モデルのクラスが定義されているファイルのパス
  • class_name : モデルのクラス名
  • state_dict : モデルの重み
  • args : モデルの引数(リスト引数)
  • kwargs : モデルの引数(キーワード引数)

読み込み用関数

読み込み
import pickle
from importlib import machinery

def load_model(filename):
    ''' モデル・重みの読み込み '''
    with open(filename, 'rb') as f:
        state = pickle.load(f)
    module = machinery.SourceFileLoader(
        state['module_path'], state['module_path']).load_module()
    args, kwargs = state['args'], state['kwargs']
    model = getattr(module, state['class_name'])(*args, **kwargs)
    model.load_state_dict(state['state_dict'])
    return model

<注意>
save_model()で保存されるファイルパスは今いるPC上での絶対パスなので,保存ファイルを別PCに転送して使う場合には,load_model()で読み込めない可能性があります。(絶対パスが同じ表記なら問題なし)
load_model()が使えない場合は,保存ファイルからstate_dict, args, kwargsだけ取り出して使用して下さい。

使用例

使用例
import torch.nn.functional as F
from torch import nn

class Net(nn.Module):
    def __init__(self, input_size, output_size):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.log_softmax(self.fc3(x))
        return x


if __name__ == '__main__':
    # モデル定義
    model = Net(128, 2)
    print(model)
    # 保存
    save_model(model, 'net.pkl', args=[128, 2])
    # 読み込み
    model = load_model('net.pkl')
    print(model)
出力
Net(
  (fc1): Linear(in_features=128, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=2, bias=True)
)
Net(
  (fc1): Linear(in_features=128, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=2, bias=True)
)
45
34
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
45
34