PyTorch には、モデルの学習可能なパラメータを 1 本のベクトルにする関数と、逆に 1 本のベクトルをモデルのパラメータにセットする関数があります。
- torch.nn.utils.parameters_to_vector — PyTorch 2.10 documentation
- torch.nn.utils.vector_to_parameters — PyTorch 2.10 documentation
おそらく以下などのときにモデルのパラメータを 1 本のベクトルにすると便利です。
- モデル全体のノルムに制約を入れたい。
- 2 つのモデルの中点を取りたい。
以下のスクリプトでモデルの学習可能なパラメータを 1 本のベクトルにし、それを全て 0.5 にして差し戻せることを確認できます。register_buffer によって登録した学習対象でないテンソルはそのままです。
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
torch.manual_seed(0)
class MyModel(torch.nn.Module): # ダミーモデル
def __init__(self, coef):
super().__init__()
self.layer0 = torch.nn.Linear(2, 3)
self.register_buffer('coef', torch.tensor(coef, dtype=torch.float))
def dumps(self):
return f'{self.layer0.weight}\n{self.layer0.bias}\n{self.coef}\n'
def forward(self, x):
y = self.layer0(x)
y += self.coef
return y
print('\n---------- Model parameters ----------')
model = MyModel([0.1, 0.2, 0.3])
print(model.dumps())
print('\n---------- Flattened vector ----------')
vec = parameters_to_vector(model.parameters())
print(vec)
print('\n---------- Flattened vector ----------')
vec[:] = 0.5 # 重みを全部 0.5 にする
print(vec)
print('\n---------- Model parameters ----------')
vector_to_parameters(vec, model.parameters())
print(model.dumps())
---------- Model parameters ----------
Parameter containing:
tensor([[-0.0053, 0.3793],
[-0.5820, -0.5204],
[-0.2723, 0.1896]], requires_grad=True)
Parameter containing:
tensor([-0.0140, 0.5607, -0.0628], requires_grad=True)
tensor([0.1000, 0.2000, 0.3000])
---------- Flattened vector ----------
tensor([-0.0053, 0.3793, -0.5820, -0.5204, -0.2723, 0.1896, -0.0140, 0.5607,
-0.0628], grad_fn=<CatBackward0>)
---------- Flattened vector ----------
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
grad_fn=<FillBackward3>)
---------- Model parameters ----------
Parameter containing:
tensor([[0.5000, 0.5000],
[0.5000, 0.5000],
[0.5000, 0.5000]], requires_grad=True)
Parameter containing:
tensor([0.5000, 0.5000, 0.5000], requires_grad=True)
tensor([0.1000, 0.2000, 0.3000])