0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

torch.nn.Module.register_buffer の挙動確認メモ

0
Posted at

PyTorch で torch.nn.Module を継承したモデルに torch.nn.Module.register_buffer によって登録したテンソルは、勾配減少方向への更新は受けないが、モデルの状態としてセーブ&ロードされます。以下の手順でそれを確認できます。

確認手順

MyModel というダミーモデルを用意します。このモデルはサイズ (*, 2) の入力をサイズ (*, 3) に線形変換し、サイズ (3,) の係数 coef を足します。ただし、サイズ (3,) の係数はインスタンス化時に指定し、 register_buffer でモデルに登録します。

  • State 0. モデルをインスタンス化すると、線形変換の重みとバイアスには requires_grad=True がついていますが、係数 coef にはついていません。
  • State 1. 実際にモデルを 1 ステップだけ更新すると、線形変換の重みとバイアスは更新されますが、係数 coef は更新されません。このモデルの状態をセーブします。
  • State 2. 新しいモデルを新しい coef の値でインスタンス化すると、このモデルの状態は当然 State 1. とは異なるものになっています。
  • State 3. そこに先ほどセーブしたモデルの状態をロードすると、coef も含めて State 1. の状態が復元されます。
script.py
import torch


class MyModel(torch.nn.Module):  # ダミーモデル
    def __init__(self, coef):
        super().__init__()
        self.layer0 = torch.nn.Linear(2, 3)
        self.register_buffer('coef', torch.tensor(coef, dtype=torch.float))
    def dumps(self):
        return f'{self.layer0.weight}\n{self.layer0.bias}\n{self.coef}\n'
    def forward(self, x):
        y = self.layer0(x)
        y += self.coef
        return y


def main_0():
    model = MyModel([0.1, 0.2, 0.3])
    print('State 0. ' + model.dumps())

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    optimizer.zero_grad()
    x = torch.tensor([[1., 1.]])  # ダミー入力
    y = model(x)
    print(f'{y=}\n')
    true = torch.tensor([[0.5, 0.5, 0.5]])  # ダミー正解
    loss = ((y - true) ** 2).mean()  # 2 乗誤差
    loss.backward()
    optimizer.step()
    print('State 1. ' + model.dumps())
    torch.save(model.state_dict(), 'hoge.pth')


def main_1():
    model = MyModel([0., 0., 0.])
    print('State 2. ' + model.dumps())

    model.load_state_dict(torch.load('hoge.pth'))
    print('State 3. ' + model.dumps())


if __name__ == '__main__':
    torch.manual_seed(1)
    main_0()
    main_1()
script.py の実行結果
State 0. Parameter containing:
tensor([[ 0.3643, -0.3121],
        [-0.1371,  0.3319],
        [-0.6657,  0.4241]], requires_grad=True)
Parameter containing:
tensor([-0.1455,  0.3597,  0.0983], requires_grad=True)
tensor([0.1000, 0.2000, 0.3000])

y=tensor([[0.0068, 0.7545, 0.1567]], grad_fn=<AddBackward0>)

State 1. Parameter containing:
tensor([[ 0.4643, -0.2121],
        [-0.2371,  0.2319],
        [-0.5657,  0.5241]], requires_grad=True)
Parameter containing:
tensor([-0.0455,  0.2597,  0.1983], requires_grad=True)
tensor([0.1000, 0.2000, 0.3000])

State 2. Parameter containing:
tensor([[-0.0866,  0.1961],
        [ 0.0349,  0.2583],
        [-0.2756, -0.0516]], requires_grad=True)
Parameter containing:
tensor([-0.0637,  0.1025, -0.0028], requires_grad=True)
tensor([0., 0., 0.])

State 3. Parameter containing:
tensor([[ 0.4643, -0.2121],
        [-0.2371,  0.2319],
        [-0.5657,  0.5241]], requires_grad=True)
Parameter containing:
tensor([-0.0455,  0.2597,  0.1983], requires_grad=True)
tensor([0.1000, 0.2000, 0.3000])
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?