背景
- 学習は GPU, 推論は CPU で行うなどで, torch.load や torch.save の仕組みを知りたい
- weight のデータをちょこっと編集したりとか, weight を自前 C++ 機械学習アプリなどで読み込みたい
ソースコードツアー
python レイヤー(torch.load
, torch.save
)は, torch/serialization.py
に実装があります.
libtorch(C++)関連
torch/csrc
にあります.
シリアライズ
torch.save
, torch.load
では, python pickle で Python object(model + weights)をシリアライズします.
これにより, モデルの現状(~= checkpoint. model + weight)を保存, ロードすることができるようになっています.
対応する Python のスクリプトも保存されるので, torch.load
でモデルを復元した場合は, 元の model のソースコードは不要になります(たぶん)
model の Python object instance(?) を取得する場合は, 通常 torch.load("checkpoint.pt")['model']
で取得できます.
libtorch
libtorch(C++)では, Python のサブセットライクな TorchScript 形式でモデルを読めますが, どうも Python コード or Python に近い言語を一から実装して, 処理しているようですね. なかなかの力技です...
weights だけの場合は, NPZ 形式(npy の複数ファイルを zip uncompressed でひとつにまとめたもの)での対応になっています.
map_location
CUDA で学習してシリアライズしたデータだと, CUDA の情報もふくまれているため, それを CPU only の環境で torch.load
するとエラーになります. これを解決するものとして map_location
オプション引数があります.
map_location に指定できるものとしてはいろいろパターンがあり, 関数(callable)を指定も可能になっています.
データコンバートメインであれば, "cpu"
文字列, もしくは torch.device('cpu')
指定で事足りるかと思います.
state_dict
モデルの weight データ(parameter)が保存されている Python dictionary(普通の Python での OrderedDict 形式)です.
Weight(パラメータ)の名前
weight にはなにかしら unique な名前がつきます.
self.my_name = nn.Linear
とすると, my_name
で Linear の weight に名前がつきます.
配列(nn.ModuleList)の場合は layers.0.my_name, my_layers.1.my_name, ... のように, .N.
のルールでつくようです.
したがって, train したモデルと, inference(TorchScript, libtorch)でモデルが異なる場合, うまく名前を見てコンバートする必要がああります.
(たとえば現状(v1.4.0)では TorchScript は nn.ModuleList
の配列アクセスが扱えないので, 手動で展開とか, 別クラスを作って iterate させるときなど)
load_state_dict
load_state_dict()
で weight データを読み込めます.
torch.load や, dict を直接指定(重みデータは torch.Tensor
がベター)でも OK です.
dict は, ネットワークのすべてのレイヤー名に対応する key が設定されている必要があります.
レイヤーの一部の weight だけ変えたい場合は,
conv0.weight = torch.Tensor(...)
conv0.bias = torch.Tensor(...)
のようにすることができます. ただこれだと shape のチェックが入らないので注意です.
(なにか他にいい方法がありそうな気はするが)
事例
checkpoint データから, weight(state_dict) だけ書き出して, 別のモデルへ読み込む(weight を設定する).
device = torch.device('cpu')
model = torch.load("pretrained.pt", map_location=device)['model']
torch.save(model.state_dict(), "weights.pt"))
weights = torch.load("weights.pt", map_location=device)
my_model = MyModel()
my_model.load_state_dict(weights)
モデルと weight 全体(pickle シリアリゼーション)も, weight(state_dict) だけもどちらも torch.load
で読み込めるのでやや混乱しますね.
双方のモデルでテンソルの大きさなどが違うとか, MyModel に対応する weight が無いとかだと, エラーを出してくれます.
libtorch
Python オブジェクトを読み書きする pickle_load, pickle_save があります
pickle 対応のために, libtorch(C++) に自前 pickle 実装してます!
https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/serialization/pickler.h
テンソルだけロード, 保存は現状できなさそうです(なにかしら Model を作る必要がある)
重みのテンソルデータだけ(基本 float32 型を想定)やりとりしたい場合は, NPY, NPZ で Numpy 形式にして, cnpy を使うのがよいでしょう. https://github.com/rogersce/cnpy
nanosnap にも cnpy でテンソルデータを読む機能があります.