0
0

PyTorchで形に制約を加えたパラメータを実装する方法

Posted at

PyTorchで、三角行列や対角行列などの形になるよう制約を加えてパラメータを実装する方法を紹介します。ここで紹介した方法を派生すれば、任意の形を保ったパラメータを実装できるはずです。

# 下三角行列の形のパラメータ     # 対角行列の形のパラメータ
[[a11,   0,   0,   0],         [[a11,   0,   0,   0],
 [a21, a22,   0,   0],          [  0, a22,   0,   0],  
 [a31, a32, a33,   0],          [  0,   0, a33,   0],
 [a41, a42, a43, a44]]          [  0,   0,   0, a44]]

以下では、三角行列対角行列の形のパラメータの実装方法を、例としてそれぞれ紹介します。

三角行列の形のパラメータ

例として、実装するモデルに次の式が登場したとします。

$$
A = LL^{\top}
$$

$A$は正方行列、$L$は下三角行列です。いずれも$n\times n$行列とします。$L$をパラメータとして実装し、$A$を計算で求めたい場合は、以下のプログラムで実現可能です。

import torch
import torch.nn as nn

class TriModel(nn.Module):
    def __init__(self, n) -> None:
        super().__init__()
        self.n = n
        num_L_elem = n * (n + 1) // 2  # Lの下三角部分の要素数
        self.L_elem = nn.Parameter(torch.randn(num_L_elem))  # Lの要素をパラメータとして保持
	
	# モデルに紐づく定数を定義する。この定数は学習されない。
        # 定数は self.zero_mat で利用可能。
        self.zero_mat_ = torch.zeros(n, n)
        self.register_buffer("zero_mat", self.zero_mat_)

    def _make_L(self) -> torch.Tensor:
        """下三角行列Lを作る"""
        L = self.zero_mat.clone()  # 下三角行列のベースとなる零行列をコピー
        L_indices = torch.tril_indices(self.n, self.n)  # Lの下三角部分のインデックスをまとめて取得
        L[*L_indices] = self.L_elem  # self.L_elemのパラメータから、下三角行列を作成
        return L

    def _calc_A(self) -> torch.Tensor:
        """Aを計算する"""
        L = self._make_L()
        A = torch.matmul(L, L.t())
        return A

    (以下省略)

今回の例ではパラメータを下三角行列の形にしましたが、上三角行列の形にしたい場合は、_make_Lメソッドにあるtorch.tril_indices(...)1torch.triu_indices(...)2に書き換えます。

対角行列の形のパラメータ

先程と同様、L_indicesのようにインデックスをまとめた変数を用意すれば、対角行列($D$とします)の形のパラメータを実装できます。以下では、先程とは異なり、インデックスを直接指定する方法で実装しています。

import torch
import torch.nn as nn

class DiagModel(nn.Module):
    def __init__(self, n) -> None:
        super().__init__()
        self.n = n
        num_D_elem = n  # Dの対角部分の要素数
        self.D_elem = nn.Parameter(torch.randn(num_D_elem))  # Dの要素をパラメータとして保持
	
	# モデルに紐づく定数を定義
        self.zero_mat_ = torch.zeros(n, n)
        self.register_buffer("zero_mat", self.zero_mat_)

    def _make_D(self) -> torch.Tensor:
        """対角行列Dを作る"""
        D = self.zero_mat.clone()  # 対角行列のベースとなる零行列をコピー
        D[range(self.n), range(self.n)] = self.D_elem
        return D

    (以下省略)

参考文献

  1. python - Enforcing a structure in a nn.Parameter (matrix) parameter in Pytorch - Stack Overflow
  2. PyTorch Moduleに紐づく定数のtensorを定義する|gota_morishita
  3. Why model.to(device) wouldn't put tensors on a custom layer to the same device? - PyTorch Forums
  4. What is the difference between register_buffer and register_parameter of nn.Module - PyTorch Forums
  5. python - Replace diagonal elements with vector in PyTorch - Stack Overflow
  1. torch.tril_indices — PyTorch 2.0 documentation

  2. torch.triu_indices — PyTorch 2.0 documentation

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0