Schönhage-Strassen
Schönhage-Strassen は長い多倍長整数 $A$ と $B$ の乗算
C = AB
を FFT っぽいアルゴリズムを使って計算する方法の 1 つである。整数環で FFT アルゴリズムを使う数論変換(NTT)の 1 つであり、特にものすごく長い桁数(1億桁とか)の多倍長乗算に適用することで計算効率を上げることができる。理由は後述。
ざっくりと手順を説明すると、数を短い多倍長数に分割し、(広義の)フーリエ変換し、項ごとにかけて(広義の)フーリエ逆変換をすることで畳み込み乗算結果を得る。以下、その詳細と共にサンプルっぽい python コードを書いておく。あくまで動作原理の参照用なので、少なくともまともに実行できる範囲では通常の *
演算子を使う方が速いことは断っておく。
def SchonhageStrassen(a, b):
n, k = SetUp(a, b)
m = 1 << (n // 2) + 1
a = NTT(Split(a, n, k), m)
b = NTT(Split(b, n, k), m)
c = MultEach(a, b, m)
return Merge(NTTInv(c, m), k)
def main():
a = 123456789012345678901234567890123456789012345678901234567890123456789
b = 314159653589793238462643383279502884197169399375105820974944592307816
print(SchonhageStrassen(a, b))
print(a * b)
0. 設定
突然の設定になるが、後のフーリエ変換で使う回転子を $w$ とする。多くの離散フーリエ変換の解説で $\zeta$ や $\omega$ と書かれている変数に該当するが、コードとの対応がわかりやすいので $w$ を使う。これの具体的な値としては 2 以上の任意の整数を取って構わないのだが、計算の利便性から $2$、$2^{64}$、$2^{128}$ などが使われやすい(と思う)。ちなみに人の手で変えられる設定は実質的にこの $w$ だけである。 サンプルの python コードでは $w=2$ ということにしている。
これ(の対数)を長さの単位として使って、元々のかける数 $A$、$B$ は $N$ 以下の長さであるという前提にする。
A, B < w^N
1. 数の分割
多倍長整数 $A$ を $w^k$ 進数で表わし、$n$ 項の数列 $\{a_i\}$ に分割する。その際、足りない部分は 0 にする。
<- k unit -> <- k unit -> <- k unit ->
+------------+-----+------------+------------+
| a[n-1] | ... | a[1] | a[0] |
+------------+-----+------------+------------+
<---- 追加分 ---><------ 元々の A の長さ------>
def Split(a, n, k):
mask = (1 << k) - 1
elements = []
for i in range(n):
elements.append(a & mask)
a >>= k
return elements
2. 広義フーリエ変換
適当な剰余環 $\mathbb{N}/m\mathbb{N}$ においてフーリエ変換と同様の変換
\alpha_k \equiv \sum_{j=0}^{n-1}a_j w^{jk} \pmod{m}
\beta_k \equiv \sum_{j=0}^{n-1}b_j w^{jk} \pmod{m}
を計算する。このとき $w$ が原始 $n$ 乗根になるよう、$m=w^{n/2}+1$ としている。ここのルーチンでは FFT アルゴリズムやそれに関連する種々の高速化が適用できるが、ここでは基本的な演算のみ書いておく。
def NTT(a, m, p=1):
n = len(a)
q = n / 2
while q >= 1:
for i in range(q):
w = pow(2, p * i, m)
for j in range(i, n, 2 * q):
k = j + q
a[j], a[k] = (a[j] + a[k]) % m, (a[j] - a[k] + m) * w % m
p, q = p * 2, q / 2
# bit-reverse shuffle
i = 0
for j in range(1, n):
k = n / 2
i ^= k
while i < k:
k = k / 2
i ^= k
if j < i:
a[j], a[i] = a[i], a[j]
return a
3. 項別乗算
これはそのまま項別にかけるだけである。 $0\leq i < n$ において
\gamma_i \equiv \alpha_i \beta_i \pmod{m}
を計算するだけである。が、$\bmod,m$ が曲者で、この乗算結果は負巡回畳み込み乗算なので $n/2$ 単位の長さで離散荷重変換乗算1などを利用することで計算量を抑えることができる。
def MultEach(a, b, m):
for i in range(len(a)):
a[i] = a[i] * b[i] % m
return a
4. 広義フーリエ逆変換
実質的に広義フーリエ変換と同じような変換
c_j \equiv n^{-1} \sum_{k=0}^{n-1}\gamma_k w^{-jk} \pmod{m}
をするだけである。ポイントとしては複素数体ではないので、順方向の変換とのルーチン共有化が面倒だったり、 $w^{-jk}$ を求める部分がちょっと面倒だったりする。ちなみに $n$ は逆数を取れる必要があるので ${\rm GCD}(m, n)=1$ でなければならないが、$w$ に 2 のべき乗を取り、計算に奇数基底の FFT アルゴリズムを使っていなければ問題ない。
def NTTInv(a, m):
n = len(a)
a = NTT(a, m, n - 1)
logn = int(math.log(n, 2))
inv = pow(2, n - logn, m) # inv * n % m = 1
for i in xrange(n):
a[i] = a[i] * inv % m
return a
5. 正規化
ここまでで畳み込み乗算結果がもとまったので、あとは $k$ 単位ずつずらしなおして、重複した部分を素直に足して ($\bmod,p$ の必要がない) いけば積が求まる。
C = \sum_{j=0}^{2n-1} c_jw^{jk}
def Merge(a, k):
val = 0
a.reverse()
for v in a:
val = (val << k) + v
return val
その他の制限
分割の際、漏れる部分があってはいけないので、
2N \leq kn
という条件ができる。$2N$ は積として出てくる数もカバーする大きさである。(負巡回を利用する場合は $N$ になる。詳しくは後でどこかに書く。) 一方で $c_i$ については畳み込み乗算結果が入っているので、法がそれより小さくなってはいけない。数式にすると
c_i < nw^{2k} < m = w^{n/2} + 1
となるが、$\{a_i\}\not\equiv -1$ ということを考えて条件を簡略化すると
\lceil\log_wn\rceil + 2k \leq \frac{n}{2}
\therefore k \leq \frac{1}{2}\left(\frac{n}{2} - \lceil\log_wn\rceil\right)
ということになる。 また、$n$ は FFT アルゴリズムを適用させるために 2 のべき乗を選ぶという前提を合わせると、$n$ に依存して限界となる $N$ や $k$ も決まる。
- $w=2$ の場合(単位:bit)
$n$ | 32 | 64 | 128 | 256 | 512 | 1024 |
---|---|---|---|---|---|---|
$k\leq$ | 5 | 13 | 28 | 60 | 123 | 251 |
$N\leq$ | 80 | 416 | 1792 | 7680 | 31488 | 128512 |
- $w=2^{64}$ の場合(単位:word)
$n$ | 32 | 64 | 128 | 256 | 512 | 1024 |
---|---|---|---|---|---|---|
$k\leq$ | 7 | 15 | 13 | 63 | 127 | 255 |
$N\leq$ | 112 | 480 | 1984 | 8064 | 32512 | 130560 |
def SetUp(a, b):
N = max(a.bit_length(), b.bit_length())
for b in xrange(5, 30):
n = 2 ** b
k = (n / 2 - b) / 2
if n * k >= N * 2:
return n, k
# 2^30 for n is too large
assert false