目次
はじめに
Pytorchモデルの保存・読み込みは,以下のような方法で行うことができます。
保存
torch.save(the_model.state_dict(), PATH)
読み込み
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
ただこの方法だと保存ファイルの他に,モデルのクラスやその引数なども覚えておく必要があり,それがちょっと不便でした。
ということで,これらを一緒に保存して保存ファイルのパスだけで読み込めるような関数を書きました。
追記 (2019/06/07)
改めて調べてみたら,もっと簡単な方法を見つけてしまいました。
cloudpickleモジュールを使用します。
import cloudpickle
保存
with open('model.pkl', 'wb') as f:
cloudpickle.dump(model, f)
読み込み
with open('model.pkl', 'rb') as f:
model = cloudpickle.load(f)
参考
https://qiita.com/kojisuganuma/items/e9b29e8e5ef5f5f289b2
保存用関数
保存
import inspect
import pickle
from copy import deepcopy
def save_model(model, filename, args=[], kwargs={}):
''' モデル・重みの保存 '''
model_cpu = deepcopy(model).cpu()
state = {'module_path': inspect.getmodule(model, _filename=True).__file__,
'class_name': model.__class__.__name__,
'state_dict': model_cpu.state_dict(),
'args': args,
'kwargs': kwargs}
with open(filename, 'wb') as f:
pickle.dump(state, f)
-
module_path
: モデルのクラスが定義されているファイルのパス -
class_name
: モデルのクラス名 -
state_dict
: モデルの重み -
args
: モデルの引数(リスト引数) -
kwargs
: モデルの引数(キーワード引数)
読み込み用関数
読み込み
import pickle
from importlib import machinery
def load_model(filename):
''' モデル・重みの読み込み '''
with open(filename, 'rb') as f:
state = pickle.load(f)
module = machinery.SourceFileLoader(
state['module_path'], state['module_path']).load_module()
args, kwargs = state['args'], state['kwargs']
model = getattr(module, state['class_name'])(*args, **kwargs)
model.load_state_dict(state['state_dict'])
return model
<注意>
save_model()
で保存されるファイルパスは今いるPC上での絶対パスなので,保存ファイルを別PCに転送して使う場合には,load_model()
で読み込めない可能性があります。(絶対パスが同じ表記なら問題なし)
load_model()
が使えない場合は,保存ファイルからstate_dict
, args
, kwargs
だけ取り出して使用して下さい。
使用例
使用例
import torch.nn.functional as F
from torch import nn
class Net(nn.Module):
def __init__(self, input_size, output_size):
super(Net, self).__init__()
self.fc1 = nn.Linear(input_size, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, output_size)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.log_softmax(self.fc3(x))
return x
if __name__ == '__main__':
# モデル定義
model = Net(128, 2)
print(model)
# 保存
save_model(model, 'net.pkl', args=[128, 2])
# 読み込み
model = load_model('net.pkl')
print(model)
出力
Net(
(fc1): Linear(in_features=128, out_features=64, bias=True)
(fc2): Linear(in_features=64, out_features=32, bias=True)
(fc3): Linear(in_features=32, out_features=2, bias=True)
)
Net(
(fc1): Linear(in_features=128, out_features=64, bias=True)
(fc2): Linear(in_features=64, out_features=32, bias=True)
(fc3): Linear(in_features=32, out_features=2, bias=True)
)