PyTorch で torch.nn.Module を継承したモデルに torch.nn.Module.register_buffer によって登録したテンソルは、勾配減少方向への更新は受けないが、モデルの状態としてセーブ&ロードされます。以下の手順でそれを確認できます。
確認手順
MyModel というダミーモデルを用意します。このモデルはサイズ (*, 2) の入力をサイズ (*, 3) に線形変換し、サイズ (3,) の係数 coef を足します。ただし、サイズ (3,) の係数はインスタンス化時に指定し、 register_buffer でモデルに登録します。
-
State 0. モデルをインスタンス化すると、線形変換の重みとバイアスには
requires_grad=Trueがついていますが、係数coefにはついていません。 -
State 1. 実際にモデルを 1 ステップだけ更新すると、線形変換の重みとバイアスは更新されますが、係数
coefは更新されません。このモデルの状態をセーブします。 -
State 2. 新しいモデルを新しい
coefの値でインスタンス化すると、このモデルの状態は当然 State 1. とは異なるものになっています。 -
State 3. そこに先ほどセーブしたモデルの状態をロードすると、
coefも含めて State 1. の状態が復元されます。
script.py
import torch
class MyModel(torch.nn.Module): # ダミーモデル
def __init__(self, coef):
super().__init__()
self.layer0 = torch.nn.Linear(2, 3)
self.register_buffer('coef', torch.tensor(coef, dtype=torch.float))
def dumps(self):
return f'{self.layer0.weight}\n{self.layer0.bias}\n{self.coef}\n'
def forward(self, x):
y = self.layer0(x)
y += self.coef
return y
def main_0():
model = MyModel([0.1, 0.2, 0.3])
print('State 0. ' + model.dumps())
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
optimizer.zero_grad()
x = torch.tensor([[1., 1.]]) # ダミー入力
y = model(x)
print(f'{y=}\n')
true = torch.tensor([[0.5, 0.5, 0.5]]) # ダミー正解
loss = ((y - true) ** 2).mean() # 2 乗誤差
loss.backward()
optimizer.step()
print('State 1. ' + model.dumps())
torch.save(model.state_dict(), 'hoge.pth')
def main_1():
model = MyModel([0., 0., 0.])
print('State 2. ' + model.dumps())
model.load_state_dict(torch.load('hoge.pth'))
print('State 3. ' + model.dumps())
if __name__ == '__main__':
torch.manual_seed(1)
main_0()
main_1()
script.py の実行結果
State 0. Parameter containing:
tensor([[ 0.3643, -0.3121],
[-0.1371, 0.3319],
[-0.6657, 0.4241]], requires_grad=True)
Parameter containing:
tensor([-0.1455, 0.3597, 0.0983], requires_grad=True)
tensor([0.1000, 0.2000, 0.3000])
y=tensor([[0.0068, 0.7545, 0.1567]], grad_fn=<AddBackward0>)
State 1. Parameter containing:
tensor([[ 0.4643, -0.2121],
[-0.2371, 0.2319],
[-0.5657, 0.5241]], requires_grad=True)
Parameter containing:
tensor([-0.0455, 0.2597, 0.1983], requires_grad=True)
tensor([0.1000, 0.2000, 0.3000])
State 2. Parameter containing:
tensor([[-0.0866, 0.1961],
[ 0.0349, 0.2583],
[-0.2756, -0.0516]], requires_grad=True)
Parameter containing:
tensor([-0.0637, 0.1025, -0.0028], requires_grad=True)
tensor([0., 0., 0.])
State 3. Parameter containing:
tensor([[ 0.4643, -0.2121],
[-0.2371, 0.2319],
[-0.5657, 0.5241]], requires_grad=True)
Parameter containing:
tensor([-0.0455, 0.2597, 0.1983], requires_grad=True)
tensor([0.1000, 0.2000, 0.3000])