1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

torch.nn.utils.parameters_to_vector / vector_to_parameters の挙動確認メモ

1
Posted at

PyTorch には、モデルの学習可能なパラメータを 1 本のベクトルにする関数と、逆に 1 本のベクトルをモデルのパラメータにセットする関数があります。

おそらく以下などのときにモデルのパラメータを 1 本のベクトルにすると便利です。

  • モデル全体のノルムに制約を入れたい。
  • 2 つのモデルの中点を取りたい。

以下のスクリプトでモデルの学習可能なパラメータを 1 本のベクトルにし、それを全て 0.5 にして差し戻せることを確認できます。register_buffer によって登録した学習対象でないテンソルはそのままです。

import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
torch.manual_seed(0)

class MyModel(torch.nn.Module):  # ダミーモデル
    def __init__(self, coef):
        super().__init__()
        self.layer0 = torch.nn.Linear(2, 3)
        self.register_buffer('coef', torch.tensor(coef, dtype=torch.float))
    def dumps(self):
        return f'{self.layer0.weight}\n{self.layer0.bias}\n{self.coef}\n'
    def forward(self, x):
        y = self.layer0(x)
        y += self.coef
        return y

print('\n---------- Model parameters ----------')
model = MyModel([0.1, 0.2, 0.3])
print(model.dumps())

print('\n---------- Flattened vector ----------')
vec = parameters_to_vector(model.parameters())
print(vec)

print('\n---------- Flattened vector ----------')
vec[:] = 0.5  # 重みを全部 0.5 にする
print(vec)

print('\n---------- Model parameters ----------')
vector_to_parameters(vec, model.parameters())
print(model.dumps())
---------- Model parameters ----------
Parameter containing:
tensor([[-0.0053,  0.3793],
        [-0.5820, -0.5204],
        [-0.2723,  0.1896]], requires_grad=True)
Parameter containing:
tensor([-0.0140,  0.5607, -0.0628], requires_grad=True)
tensor([0.1000, 0.2000, 0.3000])

---------- Flattened vector ----------
tensor([-0.0053,  0.3793, -0.5820, -0.5204, -0.2723,  0.1896, -0.0140,  0.5607,
        -0.0628], grad_fn=<CatBackward0>)

---------- Flattened vector ----------
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
       grad_fn=<FillBackward3>)

---------- Model parameters ----------
Parameter containing:
tensor([[0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5000, 0.5000]], requires_grad=True)
Parameter containing:
tensor([0.5000, 0.5000, 0.5000], requires_grad=True)
tensor([0.1000, 0.2000, 0.3000])
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?