https://qiita.com/3405691582/items/c6fa00e58181b6bb6ca5 の続きのようなメモです。
RetNetの隠れ状態を含んだ形の計算式は以下のように書けます:
O_n = Q_n(\sum_{m=0}^{n}A^{n-m}K_m^{\intercal}V_m + A^{m+1}H_{-1})
ここで系列長$n$、次元数$d$について$n\times d$行列のQuery, Key, (及びValue)は論文中式(4)で示されているxPosへの展開$Q=(XW_Q)\bigodot\Theta$, $K=(XW_K)\bigodot\bar{\Theta}$, $\Theta_n=e^{in\theta}$がされているものとし、$d\times d$行列$A$は対角行列とします。$d\times d$行列$H_{-1}$は隠れ状態です。
論文中では係数パラメータ$A$は入力に依存しないパラメータとしていますが、ここでは入力依存の係数$A_n$を考えます。
このとき上式は以下のように書き換えられます:
O_n = Q_n(\sum_{m=0}^{n}A_nA_{n-1}\cdots A_{m+2}A_{m+1}K_m^{\intercal}v_m + A_nA_{n-1}\cdots A_1A_0H_{-1})
ただし$A_nA_{n-1}\cdots A_{m+2}A_{m+1}$は$m=n$のとき単位行列$E$とします。
わかりやすさのために$n=3$まで展開してみます:
\begin{align}
O_0&=&Q_0(K_0^\intercal V_0&+&A_0H_{-1})\\
O_1&=&Q_1(K_1^\intercal V_1&+&A_1K_0^\intercal V_0&+&A_1A_0H_{-1})\\
O_2&=&Q_2(K_2^\intercal V_2&+&A_2K_1^\intercal V_1&+&A_2A_1K_0^\intercal V_0&+&A_2A_1A_0H_{-1})\\
O_3&=&Q_3(K_3^\intercal V_3&+&A_3K_2^\intercal V_2&+&A_3A_2K_1^\intercal V_1&+&A_3A_2A_1K_0^\intercal V_0&+&A_3A_2A_1A_0H_{-1})\\
\end{align}
RetNetでは入力系列のチャンク毎に計算を行うことで効率化をしており、$[0,n)$の範囲のチャンク毎の計算は次のようになります:
\begin{align}
O_n&=&(FV)_n + Q_nA_{n-1}\cdots A_0H_{-1}\\
F_{ij}&=&\sum_k(Q_{[0,n)})_{ik}D_{ikj}(K_{[0,n)}^\intercal)_{kj}\\
D_{ikj}&=&
\begin{cases}
0 & \text{if $i<j$}\\
1 & \text{if $i=j$}\\
(A_i)_{kk}\cdots(A_{j+1})_{kk} & \text{otherwise}
\end{cases}
\end{align}
$D_{ikj}$のイメージがつきづらいので$A_n$の$(0,0)$成分を$a_n$とすると$D_{i0j}$行列は以下のように書けます:
D_{i0j}=
\begin{pmatrix}
1 & 0 & 0 & 0\\
a_1 & 1 & 0 & 0\\
a_2a_1 & a_2 & 1 & 0\\
a_3a_2a_1 & a_3a_2 & a_3 & 1\\
\end{pmatrix}
問題はこの$D_{ikj}$は系列長×系列長×次元数のサイズとなるためナイーブな方法ではメモリに乗らないことです。RetNetでは$D_{ikj}$の$k$軸を潰して$A_n$をヘッド毎のスカラーとすることで対応しています。
注意するとこの行列は以下のようなベクトル同士の行列積をマスクしたものとみなすことができます:
\begin{align}
D_{i0j}&=&
\begin{pmatrix}
1 & 0 & 0 & 0\\
a_1 & 1 & 0 & 0\\
a_2a_1 & a_2 & 1 & 0\\
a_3a_2a_1 & a_3a_2 & a_3 & 1\\
\end{pmatrix}\\
&=&
\Big(
\begin{pmatrix}
1\\
a_1\\
a_2a_1\\
a_3a_2a_1
\end{pmatrix}
\begin{pmatrix}
1 & a_1^{-1} & (a_2a_1)^{-1} & (a_3a_2a_1)^{-1}
\end{pmatrix}
\Big)
\bigodot
\begin{pmatrix}
1 & 0 & 0 & 0\\
1 & 1 & 0 & 0\\
1 & 1 & 1 & 0\\
1 & 1 & 1 & 1
\end{pmatrix}
\end{align}
したがって
\begin{align}
B_{ik}&=&
\begin{cases}
1 & \text{if i=0}\\
(A_i)_{kk}\cdots(A_1)_{kk} & \text{otherwise}
\end{cases}\\
(1/B)_{ik} &=& 1/B_{ik}\\
M &=& \begin{pmatrix}
1 & 0 & 0 & 0\\
1 & 1 & 0 & 0\\
1 & 1 & 1 & 0\\
1 & 1 & 1 & 1
\end{pmatrix}
\end{align}
としたとき
F_{ij}=\sum_k(Q_{[0,n)})_{ik}B_{ik}(1/B)_{jk}(K_{[0,n)}^\intercal)_{kj} \bigodot M
のように系列長×次元数のサイズの行列の積の形で書くことができます。
$B_{ik}$は以前の記事で提案した対数畳み込みを用いて計算できます。
実装
整理してから出します