RetNetの計算式は以下のようになっています:
Y_i=X_iW_Q\sum_{j\in[0,i]}A^{i-j}W_K^\mathsf{T}X_j^\mathsf{T}X_jW_V
ここで$A$は複素対角行列で、論文では対角化された行列$BAB^\mathsf{T}$の$B,B^\mathsf{T}$が$W_Q,W_K$にパラメータとして吸収されるというロジックを取っています。
簡単のために$Q_i=X_iW_Q, Z_i=W_K^\mathsf{T}X_i^\mathsf{T}X_iW_V$と置きます。
Y_i=Q_i\sum_{j\in[0,i]}A^{i-j}Z_j
時間方向に展開すると
\begin{align}
Y_0&=Q_0Z_0\\
Y_1&=Q_1(AZ_0+Z_1)\\
Y_2&=Q_2(A^2Z_0+AZ_1+Z_2)\\
Y_3&=Q_3(A^3Z_0+A^2Z_1+AZ_2+Z_3)\\
\end{align}
のようになり、括弧内最終項直前までを隠れ状態と$A$の積$AH_{i-1}$とみなすと
\begin{align}
H_i&=AH_{i-1}+Z_i\\
Y_i&=Q_iH_i
\end{align}
の更新式が得られます。
このような隠れ状態の係数$A$が時間$i$に依存しない線形システムは時不変システムと呼ばれ、Mambaの論文では既存の線形時不変システムの性能的な制限について考察がされています。
ここで$A$を時間依存の対角行列$A_i$とし、RetNetを線形時変システムとすることを考えます。
隠れ状態の更新は以下のようになります:
\begin{align}
H_0&=&A_0H_{-1}&+&Z_0&&&&&&\\
H_1&=&A_1A_0H_{-1}&+&A_1Z_0&+&Z_1&&&&\\
H_2&=&A_2A_1A_0H_{-1}&+&A_2A_1Z_0&+&A_2Z_1&+&Z_2&&\\
H_3&=&A_3A_2A_1A_0H_{-1}&+&A_3A_2A_1Z_0&+&A_3A_2Z_1&+&A_3Z_2&+&Z_3\\
\end{align}
ここで各$A_i$は対角行列のため、これらの積は行列要素毎の積として独立に考えることができます。
一要素のスカラーに着目して以下のように小文字で書きます:
\begin{align}
h_0&=&a_0h_{-1}&+&z_0&&&&&&\\
h_1&=&a_1a_0h_{-1}&+&a_1z_0&+&z_1&&&&\\
h_2&=&a2a_1a_0h_{-1}&+&a_2a_1z_0&+&a_2z_1&+&z_2&&\\
h_3&=&a_3a_2a_1a_0h_{-1}&+&a_3a_2a_1z_0&+&a_3a_2z_1&+&a_3z_2&+&z_3\\
\end{align}
ここで$z_2$までの列の係数行列
\begin{pmatrix}
a_0 & 1 & 0 & 0\\
a_1a_0 & a_1 & 1 & 0\\
a_2a_1a_0 & a_2a_1 & a_2 & 1\\
a_3a_2a_1a_0 & a_3a_2a_1 & a_3a_2 & a_3\\
\end{pmatrix}
を効率的に計算したいです。積の形だと計算が難しいので対数$l_i=\mathrm{log}(a_i)$を取ると係数行列は
\mathrm{exp}
\begin{pmatrix}
\begin{pmatrix}
l_0 & 0 & 0 & 0\\
l_1+l_0 & l_1 & 0 & 0\\
l_2+l_1+l_0 & l_2+l_1 & l_2 & 0\\
l_3+l_2+l_1+l_0 & l_3+l_2+l_1 & l_3+l_2 & l_3\\
\end{pmatrix}
\end{pmatrix}
\bigodot
\begin{pmatrix}
1 & 1 & 0 & 0\\
1 & 1 & 1 & 0\\
1 & 1 & 1 & 1\\
1 & 1 & 1 & 1\\
\end{pmatrix}
の形に書けます。ここで$\bigodot$はアダマール積です。
$\mathrm{exp}$の中身は行方向の畳み込み演算$*$を用いて
\begin{pmatrix}
\begin{pmatrix}
l_0 & l_0 & l_0 & l_0\\
l_1 & l_1 & l_1 & l_1\\
l_2 & l_2 & l_2 & l_2\\
l_3 & l_3 & l_3 & l_3\\
\end{pmatrix}
\bigodot
\begin{pmatrix}
1 & 0 & 0 & 0\\
1 & 1 & 0 & 0\\
1 & 1 & 1 & 0\\
1 & 1 & 1 & 1\\
\end{pmatrix}
\end{pmatrix}
*
\begin{pmatrix}
1 & 1 & 1 & 1\\
1 & 1 & 1 & 1\\
1 & 1 & 1 & 1\\
1 & 1 & 1 & 1\\
\end{pmatrix}
のように計算できます。
ここでは長さ4の系列を例として説明しましたが、もちろんこのアルゴリズムは任意の系列長に対して実行できます。
実装
-現在コードを整理してますので完了次第公開します 公開しました。
注意点としては$a_i$の絶対値は発散を避けるために1未満である必要があり、また対数を取るため0に近すぎてもいけません。
実装では$X_i$の線形変換で得た複素数を滅多に0にならない性質を利用しており、SigLogのような関数に通しています