概要
Tensor Sketchは複数のベクトルのクロネッカー積の次元削減特徴を、クロネッカー積を実際に計算することなく低次元のまま低コストで計算する手法。日本語の解説記事がなかったため作成。
途中、離散フーリエ変換を利用したベクトル畳み込みをベクトルのアダマール積で計算するテクニックを使用するので、簡単にその内容も解説する。
準備
ベクトルのクロネッカー積
この記事ではベクトルの添字は$0$始まりで記述する。$m$次元のベクトルについて、$\mod m$を使って添字を指定する場面で、添字の範囲が$1,...,m$でなく$0,...,m-1$だった方が便利だからである。
ベクトル$\boldsymbol{a}, \boldsymbol{b} \in \mathbb{R}^m$を
\boldsymbol{a} = \begin{bmatrix}a_0\\ \vdots \\ a_{m-1} \end{bmatrix}, \boldsymbol{b} = \begin{bmatrix}b_0\\ \vdots \\ b_{m-1} \end{bmatrix}
とすれば、$\boldsymbol{a}, \boldsymbol{b}$のクロネッカー積$\boldsymbol{a} \otimes \boldsymbol{b}$は、
\boldsymbol{a} \otimes \boldsymbol{b} = \begin{bmatrix}a_0 \boldsymbol{b} \\ a_1 \boldsymbol{b} \\ \vdots \\ a_{m-1} \boldsymbol{b} \end{bmatrix} = \begin{bmatrix} a_0 b_0 \\ \vdots \\ a_0 b_{m-1} \\ a_1 b_0 \\ \vdots \\ a_1 b_{m-1} \\ \vdots \\ \vdots \\ a_{m-1} b_0 \\ \vdots \\ a_{m-1} b_{m-1} \end{bmatrix}
となる。$\boldsymbol{a} \otimes \boldsymbol{b} \in \mathbb{R}^{m^2}$である。
Count Sketch
Count Sketchは非常に単純な線形の次元削減法。$m$次元のベクトルの各要素を$d$次元のベクトルのどれかの要素に$\pm1$倍して和の形で入れていく方法である。
\begin{eqnarray}
&& h:[m] \rightarrow [d] \\
&& s: [m] \rightarrow \{-1, +1\} \\
\end{eqnarray}
という任意の関数を用意して、
\{\mathrm{CS}(\boldsymbol{a})\}_k = \sum_{h(j) = k} s(j) a_j
と定義される。ただし$[m]$は集合${0, 1, ..., m-1}$を意味する。
行列で書くと
\{M^{\mathrm{CS}}\}_{k, j} = \begin{cases}
s(j) & \mathrm{if} \, h(j) = k\\
0 & \mathrm{otherwise}
\end{cases}
と定義される行列に対して、
\mathrm{CS}(\boldsymbol{a}) = M^{\mathrm{CS}} \boldsymbol{a}
となる。
Count Sketchは$h$の選び方$d^m$通りと$s$の選び方$2^m$通りの積で$(2d)^m$通りが存在する。全てのCount Sketchからランダムに選択するような場合、Count Sketchは、
\mathbb{E}[\langle \mathrm{CS}(\boldsymbol{a}), \mathrm{CS}(\boldsymbol{b})\rangle] = \langle \boldsymbol{a}, \boldsymbol{b}\rangle
のような内積の期待値が元のベクトル同士の内積になるような性質がある。
Count Sketchのクロネッカー積の空間への拡張
2つのCount Sketchを$\mathrm{CS}_1, \mathrm{CS}_2$とする。それぞれのCount Sketchは$h_1, h_2$と$s_1, s_2$で定義されているとする。
この時、この2つのCount Sketch使って、クロネッカー積の空間におけるCount Sketchを自然に定義することができる。
\begin{eqnarray}
&& h_{1\otimes2}: [m^2] \rightarrow [d] \\
&& s_{1\otimes2}: [m^2] \rightarrow \{-1, +1\} \\
\end{eqnarray}
を
\begin{eqnarray}
&& h_{1\otimes2}(m j_1+j_2) \equiv h_1(j_1) + h_2(j_2) \mod d \\
&& s_{1\otimes2}(m j_1+j_2) = s(j_1) s(j_2)
\end{eqnarray}
と定義して、$h_{1\otimes2}$と$s_{1\otimes2}$で$\mathbb{R}^{m^2}$から$\mathbb{R}^{d}$へ次元を落とすCount Sketch $\mathrm{CS}_{1 \otimes 2}$を定義できる。
これはつまり、$a_{j_1}$と$b_{j_2}$の積が入っている$\boldsymbol{a} \otimes \boldsymbol{b}$の$mj_1+j_2$番目の要素を次元削減時に入れる場所の添字が、元のCount Sketchの和の$h_1(j_1) + h_2(j_2) \mod d$で定義され、係数の$\pm 1$が積の$s(j_1) s(j_2)$で定義されるというものである。
Vector Convolution
通常の関数のConvolution(畳み込み)は
(f*g)(x) = \int_{\mathbb{R}} f(y)g(x-y)dy
のように二つの関数の引数の和が$x$になる範囲で実数全体に引数を動かした積分で定義されている。
離散的かつ有限なベクトルで畳み込みを定義する場合、ベクトルが巡回して前後に続いているとみなした上で、積分でなく和で定義する。つまり、$\boldsymbol{a}, \boldsymbol{b} \in \mathbb{R}^m$に対して、
\{\boldsymbol{a} * \boldsymbol{b}\}_l = \sum_{\substack{ j+k\equiv l \mod m \\ 0 \le j, k \le m-1}} a_j b_{k}
と定義される。
離散フーリエ変換
$m$次元ベクトルの離散フーリエ変換$\mathcal{F}: \mathbb{R}^m \rightarrow \mathbb{R}^m$は以下のように定義される
\{\mathcal{F}[\boldsymbol{a}]\}_k=\sum_{j=0}^{m-1} e^{-i\frac{2 \pi kj}{m}} a_{j}
ただし、ここでは$i$を虚数単位として使用している。
$\mathcal{F}$は線形変換であり、ベクトルに対する行列積として書ける。つまり、
\{\mathcal{F}\}_{k,j} = e^{-i\frac{2\pi kj}{m}}
の行列で$\boldsymbol{a}$のフーリエ変換は$\mathcal{F} \boldsymbol{a}$となる。
この行列は逆行列$\mathcal{F}^{-1}$が存在し、これによる行列積で逆フーリエ変換を定義できる。$\mathcal{F}^{-1}$の要素は、
\{\mathcal{F}^{-1}\}_{k,j} = \frac{1}{m} e^{i\frac{2\pi kj}{m}}
となる。
Vector Convolutionの離散フーリエ変換
定義から、
\{\mathcal{F} (\boldsymbol{a} * \boldsymbol{b})\}_k = \sum_{j=0}^{m-1} e^{-i\frac{2 \pi kj}{m}} \left( \sum_{\substack{ p+q\equiv j \mod m \\ 0 \le p, q \le m-1}} a_p b_q \right)
である。$p+q\equiv j \mod m$であれば、$e^{-i\frac{2 \pi kj}{m}} = e^{-i\frac{2 \pi k (p+q)}{m}}$であるため、
\{\mathcal{F} (\boldsymbol{a} * \boldsymbol{b})\}_k = \sum_{j=0}^{m-1} \sum_{\substack{ p+q\equiv j \mod m \\ 0 \le p, q \le m-1}} \left( e^{-i\frac{2 \pi k p}{m}} a_p \right) \left( e^{-i\frac{2 \pi k q}{m}} b_q \right)
と変形できる。$0 \le j \le m-1$で$\mod m$において取りうる値の全てであり、$0 \le j \le m-1, p+q\equiv j \mod m$で$p, q$を動かして和を取ると、$0 \le p,q \le m-1$の全ての$p, q$の組み合わせでの和になる。つまり、
\begin{eqnarray}
\{\mathcal{F} (\boldsymbol{a} * \boldsymbol{b})\}_k &=& \sum_{j=0}^{m-1} \sum_{\substack{ p+q\equiv j \mod m \\ 0 \le p, q \le m-1}} \left( e^{-i\frac{2 \pi k p}{m}} a_p \right) \left( e^{-i\frac{2 \pi k q}{m}} b_q \right)\\
&=& \sum_{p=0}^{m-1} \sum_{q=0}^{m-1} \left( e^{-i\frac{2 \pi k p}{m}} a_p \right) \left( e^{-i\frac{2 \pi k q}{m}} b_q \right) \\
&=& \left( \sum_{p=0}^{m-1} e^{-i\frac{2 \pi kp}{m}} a_{p} \right) \left( \sum_{q=0}^{m-1} e^{-i\frac{2 \pi kq}{m}} b_{q} \right) \\
&=& \{\mathcal{F}(\boldsymbol{a})\}_k \{\mathcal{F}(\boldsymbol{b})\}_k
\end{eqnarray}
となる。従って、
\mathcal{F} (\boldsymbol{a} * \boldsymbol{b}) = \mathcal{F} (\boldsymbol{a}) \odot \mathcal{F} (\boldsymbol{b})
となる。だたし、$\odot$はアダマール積(ベクトルの要素ごとの積)である。
これは、通常の関数のフーリエ変換において、2つの関数の畳み込みのフーリエ変換が、それぞれの関数のフーリエ変換の積になることと対応する性質である。
Tensor Sketch
クロネッカー積の空間のCount Sketchの性質
$\mathbb{R}^{m^2}$から$\mathbb{R}^{d}$へ次元を落とすクロネッカー積の空間のCount Sketch $\mathrm{CS}_{1 \otimes 2}$は$0 \le k \le d-1$の添字に対して、
\{\mathrm{CS}_{1 \otimes 2} (\boldsymbol{a} \otimes \boldsymbol{b})\}_k = \sum_{\substack {h_1(j_1) + h_2(j_2) \equiv k \mod d \\ 0 \le j_1, j_2 \le m-1}
} s_1(j_1) s_2(j_2) a_{j_1} b_{j_2}
と定義される。
ここで、$0 \le p,q \le d-1, p+q \equiv k \mod d$を満たす$p, q$を動かす形に和を書き換えて、
\begin{eqnarray}
\{\mathrm{CS}_{1 \otimes 2} (\boldsymbol{a} \otimes \boldsymbol{b})\}_k &=& \sum_{\substack {p+q \equiv k \pmod d \\ 0 \le p, q \le d-1}} \sum_{\substack{h_1(j_1) = p \\ h_2(j_2) = q}} s_1(j_1) s_2(j_2) a_{j_1} b_{j_2} \\
&=& \sum_{\substack {p+q \equiv k \pmod d \\ 0 \le p, q \le d-1}} \left( \sum_{h_1(j_1) = p} s_1(j_1) a_{j_1} \right) \left( \sum_{h_2(j_2) = q} s_2(j_2) b_{j_2} \right) \\
&=& \sum_{\substack {p+q \equiv k \pmod d \\ 0 \le p, q \le d-1}} \{\mathrm{CS}_{1} (\boldsymbol{a})\}_p \{\mathrm{CS}_{2} (\boldsymbol{b})\}_q
\end{eqnarray}
となる。これは2つの$d$次元ベクトル$\mathrm{CS}_1 (\boldsymbol{a}), \mathrm{CS}_2 (\boldsymbol{b})$についてのVector Comvolutionになっている。つまり、
\mathrm{CS}_{1 \otimes 2} (\boldsymbol{a} \otimes \boldsymbol{b}) = \mathrm{CS}_1 (\boldsymbol{a}) * \mathrm{CS}_2 (\boldsymbol{b})
になる。このクロネッカー積の空間のCount Sketchは、各Count SketchのVector Comvolutionであるという性質を用いることで、クロネッカー積の空間のCount Sketchを高速に計算することができる。
フーリエ変換を利用した高速なクロネッカー積の空間のCount Sketchの計算
これがTensor Sketchと呼ばれる手法。
ここまでの議論により、クロネッカー積の空間のCount Sketchを離散フーリエ変換すると、Vector Comvolutionの離散フーリエ変換の性質も用いて、
\mathcal{F} (\mathrm{CS}_{1 \otimes 2} (\boldsymbol{a} \otimes \boldsymbol{b})) = \mathcal{F} (\mathrm{CS}_1 (\boldsymbol{a}) * \mathrm{CS}_2 (\boldsymbol{b})) = \mathcal{F} (\mathrm{CS}_1 (\boldsymbol{a})) \odot \mathcal{F} (\mathrm{CS}_2 (\boldsymbol{b}))
となるから、
\mathrm{CS}_{1 \otimes 2} (\boldsymbol{a} \otimes \boldsymbol{b}) = \mathcal{F}^{-1} (\mathcal{F} (\mathrm{CS}_1 (\boldsymbol{a})) \odot \mathcal{F} (\mathrm{CS}_2 (\boldsymbol{b})))
である。
これにより、クロネッカー積の$m^2$次元ベクトルを$d$次元に次元削減する計算を、$m$次元ベクトルを$d$次元に次元削減する計算2つと、それらの$d$次元離散フーリエ変換、$d$次元のアダマール積、$d$次元離散逆フーリエ変換の合成に変えることができる。
$d$が2の累乗の場合、離散フーリエ変換はFFTを使えるので、$d^2$でなく$d \log d$のオーダーで計算できる。その場合、素朴には計算量のオーダーは$\mathcal{O}(m^2d)$から$\mathcal{O}(md + d \log d)$になる。
$d$に比べて$m$が大きい場合に計算コストは大きく削減される。
3つ以上のベクトルのクロネッカー積への拡張
クロネッカー積の2項演算は、2つのベクトルの次元が同じである必要はない。また、クロネッカー積は結合法則
\boldsymbol{a} \otimes (\boldsymbol{b} \otimes \boldsymbol{c}) = (\boldsymbol{a} \otimes \boldsymbol{b}) \otimes \boldsymbol{c}
が成り立つので、これを$\boldsymbol{a} \otimes \boldsymbol{b} \otimes \boldsymbol{c}$のように2項演算の順序を省略して書いても問題ない。4つ以上のベクトルのクロネッカー積も同様に定義し、記載できる。
$\boldsymbol{a}, \boldsymbol{b}, \boldsymbol{c} \in \mathbb{R}^m$の場合、$\boldsymbol{a} \otimes \boldsymbol{b} \otimes \boldsymbol{c} \in \mathbb{R}^{m^3}$になる。
詳しい導出は省略するが、ここまでと同様の議論で3つ以上のベクトルのTensor Sketchでも、
\mathrm{CS}_{1 \otimes 2 \otimes 3} (\boldsymbol{a} \otimes \boldsymbol{b} \otimes \boldsymbol{c}) = \mathcal{F}^{-1} (\mathcal{F} (\mathrm{CS}_1 (\boldsymbol{a})) \odot \mathcal{F} (\mathrm{CS}_2 (\boldsymbol{b})) \odot \mathcal{F} (\mathrm{CS}_3 (\boldsymbol{c})))
のようにそれぞれのベクトルのCount Sketch、フーリエ変換、アダマール積、逆フーリエ変換で計算できる。$K$個のベクトルのクロネッカー積のCount Sketchの計算の場合、クロネッカー積$\mathbb{R}^{m^K}$から$d$に次元削減する計算について、計算量のオーダーは素朴には定義通りの計算と比べて$\mathcal{O}(m^K d)$から$\mathcal{O} (Kmd + Kd \log d)$になる。$K$がある程度の数になる場合、$m$が$d$よりも大きければ計算コストの減少は顕著になる。