PyTorchで、三角行列や対角行列などの形になるよう制約を加えてパラメータを実装する方法を紹介します。ここで紹介した方法を派生すれば、任意の形を保ったパラメータを実装できるはずです。
# 下三角行列の形のパラメータ # 対角行列の形のパラメータ
[[a11, 0, 0, 0], [[a11, 0, 0, 0],
[a21, a22, 0, 0], [ 0, a22, 0, 0],
[a31, a32, a33, 0], [ 0, 0, a33, 0],
[a41, a42, a43, a44]] [ 0, 0, 0, a44]]
以下では、三角行列と対角行列の形のパラメータの実装方法を、例としてそれぞれ紹介します。
三角行列の形のパラメータ
例として、実装するモデルに次の式が登場したとします。
$$
A = LL^{\top}
$$
$A$は正方行列、$L$は下三角行列です。いずれも$n\times n$行列とします。$L$をパラメータとして実装し、$A$を計算で求めたい場合は、以下のプログラムで実現可能です。
import torch
import torch.nn as nn
class TriModel(nn.Module):
def __init__(self, n) -> None:
super().__init__()
self.n = n
num_L_elem = n * (n + 1) // 2 # Lの下三角部分の要素数
self.L_elem = nn.Parameter(torch.randn(num_L_elem)) # Lの要素をパラメータとして保持
# モデルに紐づく定数を定義する。この定数は学習されない。
# 定数は self.zero_mat で利用可能。
self.zero_mat_ = torch.zeros(n, n)
self.register_buffer("zero_mat", self.zero_mat_)
def _make_L(self) -> torch.Tensor:
"""下三角行列Lを作る"""
L = self.zero_mat.clone() # 下三角行列のベースとなる零行列をコピー
L_indices = torch.tril_indices(self.n, self.n) # Lの下三角部分のインデックスをまとめて取得
L[*L_indices] = self.L_elem # self.L_elemのパラメータから、下三角行列を作成
return L
def _calc_A(self) -> torch.Tensor:
"""Aを計算する"""
L = self._make_L()
A = torch.matmul(L, L.t())
return A
(以下省略)
今回の例ではパラメータを下三角行列の形にしましたが、上三角行列の形にしたい場合は、_make_L
メソッドにあるtorch.tril_indices(...)
1をtorch.triu_indices(...)
2に書き換えます。
対角行列の形のパラメータ
先程と同様、L_indices
のようにインデックスをまとめた変数を用意すれば、対角行列($D$とします)の形のパラメータを実装できます。以下では、先程とは異なり、インデックスを直接指定する方法で実装しています。
import torch
import torch.nn as nn
class DiagModel(nn.Module):
def __init__(self, n) -> None:
super().__init__()
self.n = n
num_D_elem = n # Dの対角部分の要素数
self.D_elem = nn.Parameter(torch.randn(num_D_elem)) # Dの要素をパラメータとして保持
# モデルに紐づく定数を定義
self.zero_mat_ = torch.zeros(n, n)
self.register_buffer("zero_mat", self.zero_mat_)
def _make_D(self) -> torch.Tensor:
"""対角行列Dを作る"""
D = self.zero_mat.clone() # 対角行列のベースとなる零行列をコピー
D[range(self.n), range(self.n)] = self.D_elem
return D
(以下省略)
参考文献
- python - Enforcing a structure in a nn.Parameter (matrix) parameter in Pytorch - Stack Overflow
- PyTorch Moduleに紐づく定数のtensorを定義する|gota_morishita
- Why model.to(device) wouldn't put tensors on a custom layer to the same device? - PyTorch Forums
- What is the difference between
register_buffer
andregister_parameter
ofnn.Module
- PyTorch Forums - python - Replace diagonal elements with vector in PyTorch - Stack Overflow