こんにちは。株式会社オプティマインドで最適化チームに所属している伊豆原と申します。
本日は当社の競プロ部の活動の一環として、高速なxor畳み込みについて書かさせていただきます。なお、高速なんたら変換に関しては過去にも以下のような記事を書いておりまして、その第3弾になります。
xor畳み込みとは
$u=(x_0,\cdots,x_{n-1})$と$v=(y_0,\cdots,y_{n-1})$を長さ$n$の実数ベクトルとすると、xor畳み込み$u\ast v$はxor演算を$\oplus$としてそのk番目の成分が
(u \ast v)_k = \sum_{k = i \oplus j} x_i y_j
となるベクトルして定義されます。例えば$n=2$ならば$u\ast v = (x_0 y_0 + x_1 y_1, x_0 y_1 + x_1 y_0)^T$です。
xor畳み込みは愚直に行えば$\mathcal{O}(n^2)$の計算量になりますが、高速Walsh-Hadamard変換という手法を使うことで$\mathcal{O}(n \log n)$で達成することができます。本記事ではこれを紹介したいと思います。
xor畳み込みの行列表現
まずはn=2におけるxor畳み込みを考えてみます。$u=(x_0,x_1)^T$と$v=(y_0,y_1)^T$とのxor畳み込みは、天下り的ですが行列$H$を
H = \dfrac{1}{\sqrt{2}}\begin{pmatrix} 1 & 1 \\ 1 & -1 \end{pmatrix}
と定義すれば$\sqrt{2} H(Hu \odot Hv)$として計算できます($\odot$は成分ごとの積)。実際に計算を書き下しますと
Hu = \dfrac{1}{\sqrt{2}} \begin{pmatrix} 1 & 1 \\ 1 & -1 \end{pmatrix} \begin{pmatrix} x_0 \\ x_1 \end{pmatrix} =\dfrac{1}{\sqrt{2}}\begin{pmatrix} x_0+x_1 \\ x_0 - x_1\end{pmatrix} \\
Hv = \dfrac{1}{\sqrt{2}} \begin{pmatrix} 1 & 1 \\ 1 & -1 \end{pmatrix} \begin{pmatrix} y_0 \\ y_1 \end{pmatrix} =\dfrac{1}{\sqrt{2}}\begin{pmatrix} y_0+y_1 \\ y_0 - y_1\end{pmatrix} \\
\sqrt{2} H(Hu \odot Hv) = \dfrac{1}{2} \begin{pmatrix} 1 & 1 \\ 1 & -1 \end{pmatrix} \begin{pmatrix} (x_0+x_1)(y_0+y_1) \\ (x_0-x_1)(y_0-y_1) \end{pmatrix} = \begin{pmatrix} x_0 y_0+x_1 y_1 \\ x_0 y_1 + x_1 y_0 \end{pmatrix}
となります。係数がうまく調整されてxorの演算と一致するわけです。
では一般に$n=2^k$の時はどうなるでしょうか?実は上の行列$H_1=H$から再帰的に
H_{i+1} = \dfrac{1}{\sqrt{2}}\begin{pmatrix} H_i & H_i \\ H_i & -H_i \end{pmatrix}
と定義すれば$2^{k/2}H_k ( H_k u \odot H_k v)$がxor畳み込みになります。
これを帰納法で簡単に説明します。$u,v$を長さ$2^{k+1}$のベクトルとしたとき、それらを長さ$2^k$のベクトル$u_0,u_1,v_0,v_1$に分割します($u=(u_0^T,u_1^T)^T,v=(v_0^T,v_1^T)^T$)。これはちょうどベクトル成分の添字を2進数表記する時、下から(k+1)-bit目が$0$のものと$1$のものに分かれる感じです。この時、計算から
2^{(k+1)/2} H_{k+1}(H_{k+1} u\odot H_{k+1} v) =
\dfrac{2^{k/2}}{2} \begin{pmatrix} H_k & H_k \\ H_k & -H_k \end{pmatrix} \begin{pmatrix} (H_k u_0 + H_k u_1) \odot (H_k v_0 + H_k v_1) \\ (H_k u_0 - H_k u_1) \odot (H_k v_0 - H_k v_1) \end{pmatrix} \\
= 2^{k/2}\begin{pmatrix} H_k (H_ku_0 \odot H_K v_0) + H_k (H_k u_1 \odot H_k v_1) \\ H_k (H_ku_0 \odot H_K v_1) + H_k (H_k u_1 \odot H_k v_0)\end{pmatrix}
= \begin{pmatrix} u_0 \ast v_0 + u_1 \ast v_1 \\ u_0 \ast v_1 + u_1 \ast v_0 \end{pmatrix} = u \ast v
となります。ここで最後の等式は、各$u_i,v_i (i = 0,1)$の添字の(k+1)-bit目が$i$であることから導けます。
以上から、xor畳み込みが行列とベクトルの積によって表現されることが分かりました。なお、本章で現れた$H_k$による掛け算はHadamard変換と呼ばれるものになります。
高速Walsh-Hadamard変換
前章のとおりxor畳み込みは行列とベクトルの積を計算すればよいのですが、このまま計算すると計算量が$\mathcal{O}(n^2)$になってしまいます。この計算量を$\mathcal{O}(n\log n)$に改善するのが高速Walsh-Hadamard変換と呼ばれる方法で、以下のように計算します。
求めたいのは$H_k u$です。$H_k$は$H_{k-1}$から再帰的に定義されていたことを思い出して$H_k u$を分解していきます。$E_k$を$k$次元の単位行列とすれば
H_k u =
\dfrac{1}{\sqrt{2}}
\begin{pmatrix}
E_{k-1} & E_{k-1} \\
E_{k-1} & -E_{k-1}
\end{pmatrix}
\begin{pmatrix}
H_{k-1} & 0 \\
0 & H_{k-1}
\end{pmatrix}u \\
=
\dfrac{1}{2}
\begin{pmatrix}
E_{k-1} & E_{k-1} \\
E_{k-1} & -E_{k-1}
\end{pmatrix}
\begin{pmatrix}
E_{k-2} & E_{k-2} & 0 & 0 \\
E_{k-2} & -E_{k-2} & 0 & 0 \\
0 & 0 & E_{k-2} & E_{k-2} \\
0 & 0 & E_{k-2} & -E_{k-2} \\
\end{pmatrix}
\begin{pmatrix}
H_{k-2} & 0 & 0 & 0 \\
0 & H_{k-2} & 0 & 0 \\
0 & 0 & H_{k-2} & 0 \\
0 & 0 & 0 & H_{k-2} \\
\end{pmatrix}
u = \cdots
この分解では最終的に$H_1$が対角線上に並んだものが現れます。特に、分解に現れる行列は各行に非零な要素を2つしかもちませんので、ベクトルとの掛け算が$\mathcal{O}(n)$で計算できます。分解を右から順に計算していけば、行列×ベクトルの計算が$\mathcal{O}(n)$で、掛け算自体が$\mathcal{O}(\log n)$回あるので計算量は$\mathcal{O}(n\log n)$となります。ソースコードにすると以下のようになります(ただし係数の2の冪乗部分は無視しています)。
# リスト"u"の長さは2のべき乗であることを仮定
def FastWalshHadamardTransform(u):
k = (len(u) - 1).bit_length()
h = 1
for _ in range(k):
for i in range(0,len(u),h*2):
for j in range(i,i+h):
u[j],u[j+h] = u[j]+u[j+h],u[j]-u[j+h]
h *= 2
行われている計算を部分的に図にすると、以下の感じになります(赤矢印が足し算、青矢印が引き算)。
定義したFastWalshHadamardTransform
を使えば、xor畳み込みは以下のように実装できます。
# リスト"u,v"の長さは2のべき乗であることを仮定
def xorConvolution(u,v):
k = (len(X)-1).bit_length()
FastWalshHadamardTransform(X)
FastWalshHadamardTransform(Y)
for i in range(len(X)):
X[i] *= Y[i]
FastWalshHadamardTransform(X)
for i in range(len(X)):
X[i] >>= k # ここで2の冪乗分をまとめて行う
return X
以上で、本記事の主題である高速Hadamard変換によるxor畳み込みが達成できました。
Appendix 1 : Fourier変換としてのHadamard変換
先ほど天下り的に与えた行列$H$を自然に導く方法を書いておきます。これには有限Abel群上のFourier変換を使います。
集合$\{0,1\}$上のxor演算は位数$2$の有限群$G_2 = \mathbb{Z}/2\mathbb{Z}$における足し算と同じものになります。
長さ2のベクトル$u=(x_0,x_1)$は下の定義によって$G_2$上の関数$f_u$とみなせ、長さ2のベクトル全体は関数環$C(G_2)= \{f : G_2 \to \mathbb{C}\}$と同一視できます。
f_u : G_2 \to \mathbb{C}; a \mapsto x_a \in \mathbb{C}
では$\hat{G_2}$を$G_2$の指標群として、$f_u$のFourier変換$F(f_u) \in C(\hat{G_2})$を考えてみます。
F(f_u) : \hat{G_2} \to \mathbb{C} \\
\hat{G_2} \owns \xi \mapsto \dfrac{1}{\sqrt{|G_2|}}\sum_{x \in G_2} \xi (x) f_u(x) \in \mathbb{C}
$G_2$の指標は自明な指標$\chi_0$と冪零な指標$\chi_1(x)=(-1)^x ( x \in G_2=\{0,1\})$の2つなので、$G_2$と同様に関数環$C(\hat{G_2})$は長さ2のベクトル全体と同一視できます。特にこの時、Fourier変換は次元2のベクトル空間間の線形写像とみなせるため、$2\times 2$行列の積として表せることが分かります。
またFourier変換の性質から、$C(G_2)$の2つの関数$f,g$の畳み込み(ベクトルとしてはxor畳み込み)のFourier変換は、それぞれのFourier変換後の関数の積になります(ただし上の定義ですと、係数の関係で定数倍ずれて$F(f\ast g) = \sqrt{2} F(f)\cdot F(g)$になります)。得られた等式の両辺にもう一度Fourier変換を行えば$f\ast g = \sqrt{2} F (F(f)\cdot F(g))$となり、xor畳み込みが行列積と成分ごとの積で表されることが分かります。
ではFourier変換$F$を表す具体的な行列を計算します。関数$F(f_u)\in C(\hat{G_2})$は$\hat{G_2}$上で
\chi_0 \mapsto \dfrac{1}{\sqrt{2}}(f_u(0)+f_u(1)) \\
\chi_1 \mapsto \dfrac{1}{\sqrt{2}}(f_u(0) - f_u(1))
という関数になります。特に$u=(x_0,x_1)^T$の成分を使って書けば$f_u(0) = x_0, f_u(1) = x_1$なので、Fourier変換は下記の行列$H$による積になります。これはHadamard変換そのものですね!
H
\begin{pmatrix}
x_0 \\
x_1
\end{pmatrix}
=
\dfrac{1}{\sqrt{2}}
\begin{pmatrix}
1 & 1 \\
1 & -1 \\
\end{pmatrix}
\begin{pmatrix}
x_0 \\
x_1
\end{pmatrix}
以上、演繹的にxor畳み込みを表す行列$H$を導くことができました。なお、整数$0\sim 2^k-1$とxor演算の成す群と$\mathbb{Z}/2\mathbb{Z}$の$k$個の直和の間の同型を考えれば、一般の$2^k$次元のベクトルに対して同様の方法でxor畳み込みを表す行列を計算できますし、それらは先ほど再帰的に構成した行列と一致します。抽象化は偉大ですね。
Appendix 2 : and/or畳み込みについて
実は行列の形を変えることで、and畳み込みとor畳み込みも実現することができます。実際、2次元の時はand畳み込みならば
G :=
\begin{pmatrix}
1 & 1 \\
0 & 1
\end{pmatrix},
G^{-1} :=
\begin{pmatrix}
1 & -1 \\
0 & 1
\end{pmatrix}
or畳み込みならば
G :=
\begin{pmatrix}
1 & 0 \\
1 & 1
\end{pmatrix},
G^{-1} :=
\begin{pmatrix}
1 & 0 \\
-1 & 1
\end{pmatrix}
とすれば、2つのベクトル$u,v$に対して$G^{-1}(Gu\odot Gv)$がand/or畳み込みになっていることが計算から分かります。$2^k(k>1)$次元に対してもxor畳み込み時と同様に再帰的に行列を構成すれば、これまた同様に帰納法でand/or畳み込みになっていることが分かります。
では、これらの行列演算が何を行っているかを確認してみましょう。and畳み込みでは4次元のベクトル$u = (x_0,x_1,x_2,x_3)^T$に対して$G$を掛けると次のようになります。
Gu=
\begin{pmatrix}
1 & 1 & 1 & 1 \\
0 & 1 & 0 & 1 \\
0 & 0 & 1 & 1 \\
0 & 0 & 0 & 1 \\
\end{pmatrix}
\begin{pmatrix}
x_0 \\
x_1 \\
x_2 \\
x_3 \\
\end{pmatrix}
=
\begin{pmatrix}
x_0 + x_1 + x_2 + x_3\\
x_1 + x_3 \\
x_2 + x_3 \\
x_3 \\
\end{pmatrix}
実はこの行列積は、0〜3の整数をビット表現して要素数2の集合の部分集合全体と捉えた時のゼータ変換と捉えることができます。つまり整数$n$に対して対応する部分集合を$S_n$と書いた時の、下記右辺の計算です。
\begin{pmatrix}
x_0 + x_1 + x_2 + x_3\\
x_1 + x_3 \\
x_2 + x_3 \\
x_3 \\
\end{pmatrix}
=
\begin{pmatrix}
\sum_{\{i|S_0 \subset S_i \}} x_i \\
\sum_{\{i|S_1 \subset S_i \}} x_i \\
\sum_{\{i|S_2 \subset S_i \}} x_i \\
\sum_{\{i|S_3 \subset S_i \}} x_i \\
\end{pmatrix}
同様に$G^{-1}$の積はメビウス変換と捉えることができます。つまり計算の流れとしては、まず$S_i \cup S_j \subseteq S_k$となる$(i,j)$を列挙して足し上げて(これは$S_i \subseteq S_k$となる$i$を2つ並べたものになるので、ゼータ変換してから成分ごとの積を取れば計算できます)、その後でメビウス変換によって$S_i \cup S_j \subsetneq S_k$となる部分を取り除いてることになります。賢いですね。
or畳み込みのほうも(集合の包含関係は逆向きになりますが)同様の流れで考えることができます。
まとめますと、xor畳み込みにはHadamard変換(もといFourier変換)、and/or畳み込みにはゼータ変換及びメビウス変換という異なる背景を見出すことができるのですが、両方とも行列積で表現できるため同様のフレームワークで実装できるというわけです。とても興味深いですね。
参考文献
以下のサイト記事や書籍を参考にさせて頂きました。