5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorch/libtorch の load/save メモ(主にデータコンバート向け)

Last updated at Posted at 2020-04-08

背景

  • 学習は GPU, 推論は CPU で行うなどで, torch.load や torch.save の仕組みを知りたい
  • weight のデータをちょこっと編集したりとか, weight を自前 C++ 機械学習アプリなどで読み込みたい

ソースコードツアー

python レイヤー(torch.load, torch.save)は, torch/serialization.py に実装があります.

libtorch(C++)関連

torch/csrc にあります.

シリアライズ

torch.save, torch.load では, python pickle で Python object(model + weights)をシリアライズします.
これにより, モデルの現状(~= checkpoint. model + weight)を保存, ロードすることができるようになっています.

対応する Python のスクリプトも保存されるので, torch.load でモデルを復元した場合は, 元の model のソースコードは不要になります(たぶん)

model の Python object instance(?) を取得する場合は, 通常 torch.load("checkpoint.pt")['model'] で取得できます.

libtorch

libtorch(C++)では, Python のサブセットライクな TorchScript 形式でモデルを読めますが, どうも Python コード or Python に近い言語を一から実装して, 処理しているようですね. なかなかの力技です...

weights だけの場合は, NPZ 形式(npy の複数ファイルを zip uncompressed でひとつにまとめたもの)での対応になっています.

map_location

CUDA で学習してシリアライズしたデータだと, CUDA の情報もふくまれているため, それを CPU only の環境で torch.load するとエラーになります. これを解決するものとして map_location オプション引数があります.
map_location に指定できるものとしてはいろいろパターンがあり, 関数(callable)を指定も可能になっています.

データコンバートメインであれば, "cpu" 文字列, もしくは torch.device('cpu') 指定で事足りるかと思います.

state_dict

モデルの weight データ(parameter)が保存されている Python dictionary(普通の Python での OrderedDict 形式)です.

Weight(パラメータ)の名前

weight にはなにかしら unique な名前がつきます.

self.my_name = nn.Linear

とすると, my_name で Linear の weight に名前がつきます.
配列(nn.ModuleList)の場合は layers.0.my_name, my_layers.1.my_name, ... のように, .N. のルールでつくようです.

したがって, train したモデルと, inference(TorchScript, libtorch)でモデルが異なる場合, うまく名前を見てコンバートする必要がああります.
(たとえば現状(v1.4.0)では TorchScript は nn.ModuleList の配列アクセスが扱えないので, 手動で展開とか, 別クラスを作って iterate させるときなど)

load_state_dict

load_state_dict() で weight データを読み込めます.
torch.load や, dict を直接指定(重みデータは torch.Tensor がベター)でも OK です.

dict は, ネットワークのすべてのレイヤー名に対応する key が設定されている必要があります.
レイヤーの一部の weight だけ変えたい場合は,

conv0.weight = torch.Tensor(...)
conv0.bias = torch.Tensor(...)

のようにすることができます. ただこれだと shape のチェックが入らないので注意です.
(なにか他にいい方法がありそうな気はするが)

事例

checkpoint データから, weight(state_dict) だけ書き出して, 別のモデルへ読み込む(weight を設定する).

device = torch.device('cpu')

model = torch.load("pretrained.pt", map_location=device)['model']

torch.save(model.state_dict(), "weights.pt"))

weights = torch.load("weights.pt", map_location=device)

my_model = MyModel()
my_model.load_state_dict(weights)

モデルと weight 全体(pickle シリアリゼーション)も, weight(state_dict) だけもどちらも torch.load で読み込めるのでやや混乱しますね.

双方のモデルでテンソルの大きさなどが違うとか, MyModel に対応する weight が無いとかだと, エラーを出してくれます.

libtorch

Python オブジェクトを読み書きする pickle_load, pickle_save があります

pickle 対応のために, libtorch(C++) に自前 pickle 実装してます!
https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/serialization/pickler.h

テンソルだけロード, 保存は現状できなさそうです(なにかしら Model を作る必要がある)

重みのテンソルデータだけ(基本 float32 型を想定)やりとりしたい場合は, NPY, NPZ で Numpy 形式にして, cnpy を使うのがよいでしょう. https://github.com/rogersce/cnpy

nanosnap にも cnpy でテンソルデータを読む機能があります.

5
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?