LoginSignup
15
9

More than 3 years have passed since last update.

Schönhage-Strassen のアルゴリズムメモ

Last updated at Posted at 2016-01-14

Schönhage-Strassen

Schönhage-Strassen は長い多倍長整数 $A$ と $B$ の乗算

C = AB

を FFT っぽいアルゴリズムを使って計算する方法の 1 つである。整数環で FFT アルゴリズムを使う数論変換(NTT)の 1 つであり、特にものすごく長い桁数(1億桁とか)の多倍長乗算に適用することで計算効率を上げることができる。理由は後述。

ざっくりと手順を説明すると、数を短い多倍長数に分割し、(広義の)フーリエ変換し、項ごとにかけて(広義の)フーリエ逆変換をすることで畳み込み乗算結果を得る。以下、その詳細と共にサンプルっぽい python コードを書いておく。あくまで動作原理の参照用なので、少なくともまともに実行できる範囲では通常の * 演算子を使う方が速いことは断っておく。

def SchonhageStrassen(a, b):
  n, k = SetUp(a, b)
  m = 1 << (n // 2) + 1
  a = NTT(Split(a, n, k), m)
  b = NTT(Split(b, n, k), m)
  c = MultEach(a, b, m)
  return Merge(NTTInv(c, m), k)

def main():
  a = 123456789012345678901234567890123456789012345678901234567890123456789
  b = 314159653589793238462643383279502884197169399375105820974944592307816
  print(SchonhageStrassen(a, b))
  print(a * b)

0. 設定

突然の設定になるが、後のフーリエ変換で使う回転子を $w$ とする。多くの離散フーリエ変換の解説で $\zeta$ や $\omega$ と書かれている変数に該当するが、コードとの対応がわかりやすいので $w$ を使う。これの具体的な値としては 2 以上の任意の整数を取って構わないのだが、計算の利便性から $2$、$2^{64}$、$2^{128}$ などが使われやすい(と思う)。ちなみに人の手で変えられる設定は実質的にこの $w$ だけである。 サンプルの python コードでは $w=2$ ということにしている。

これ(の対数)を長さの単位として使って、元々のかける数 $A$、$B$ は $N$ 以下の長さであるという前提にする。

A, B < w^N

1. 数の分割

多倍長整数 $A$ を $w^k$ 進数で表わし、$n$ 項の数列 $\{a_i\}$ に分割する。その際、足りない部分は 0 にする。

 <- k unit ->       <- k unit -> <- k unit ->
+------------+-----+------------+------------+
|   a[n-1]   | ... |    a[1]    |    a[0]    |
+------------+-----+------------+------------+
 <---- 追加分 ---><------ 元々の A の長さ------>
def Split(a, n, k):
  mask = (1 << k) - 1
  elements = []
  for i in range(n):
    elements.append(a & mask)
    a >>= k
  return elements

2. 広義フーリエ変換

適当な剰余環 $\mathbb{N}/m\mathbb{N}$ においてフーリエ変換と同様の変換

\alpha_k \equiv \sum_{j=0}^{n-1}a_j w^{jk} \pmod{m}
\beta_k \equiv \sum_{j=0}^{n-1}b_j w^{jk} \pmod{m}

を計算する。このとき $w$ が原始 $n$ 乗根になるよう、$m=w^{n/2}+1$ としている。ここのルーチンでは FFT アルゴリズムやそれに関連する種々の高速化が適用できるが、ここでは基本的な演算のみ書いておく。

def NTT(a, m, p=1):
  n = len(a)
  q = n / 2
  while q >= 1:
    for i in range(q):
      w = pow(2, p * i, m)
      for j in range(i, n, 2 * q):
        k = j + q
        a[j], a[k] = (a[j] + a[k]) % m, (a[j] - a[k] + m) * w % m
    p, q = p * 2, q / 2

  # bit-reverse shuffle
  i = 0
  for j in range(1, n):
    k = n / 2
    i ^= k
    while i < k:
      k = k / 2
      i ^= k
    if j < i:
      a[j], a[i] = a[i], a[j]

  return a

3. 項別乗算

これはそのまま項別にかけるだけである。 $0\leq i < n$ において

\gamma_i \equiv \alpha_i \beta_i \pmod{m}

を計算するだけである。が、$\bmod\,m$ が曲者で、この乗算結果は負巡回畳み込み乗算なので $n/2$ 単位の長さで離散荷重変換乗算1などを利用することで計算量を抑えることができる。

def MultEach(a, b, m):
  for i in range(len(a)):
    a[i] = a[i] * b[i] % m
  return a

4. 広義フーリエ逆変換

実質的に広義フーリエ変換と同じような変換

c_j \equiv n^{-1} \sum_{k=0}^{n-1}\gamma_k w^{-jk} \pmod{m}

をするだけである。ポイントとしては複素数体ではないので、順方向の変換とのルーチン共有化が面倒だったり、 $w^{-jk}$ を求める部分がちょっと面倒だったりする。ちなみに $n$ は逆数を取れる必要があるので ${\rm GCD}(m, n)=1$ でなければならないが、$w$ に 2 のべき乗を取り、計算に奇数基底の FFT アルゴリズムを使っていなければ問題ない。

def NTTInv(a, m):
  n = len(a)
  a = NTT(a, m, n - 1)

  logn = int(math.log(n, 2))
  inv = pow(2, n - logn, m)  # inv * n % m = 1
  for i in xrange(n):
    a[i] = a[i] * inv % m
  return a

5. 正規化

ここまでで畳み込み乗算結果がもとまったので、あとは $k$ 単位ずつずらしなおして、重複した部分を素直に足して ($\bmod\,p$ の必要がない) いけば積が求まる。

C = \sum_{j=0}^{2n-1} c_jw^{jk}
def Merge(a, k):
  val = 0
  a.reverse()
  for v in a:
    val = (val << k) + v
  return val

その他の制限

分割の際、漏れる部分があってはいけないので、

2N \leq kn

という条件ができる。$2N$ は積として出てくる数もカバーする大きさである。(負巡回を利用する場合は $N$ になる。詳しくは後でどこかに書く。) 一方で $c_i$ については畳み込み乗算結果が入っているので、法がそれより小さくなってはいけない。数式にすると

c_i < nw^{2k} < m = w^{n/2} + 1

となるが、$\{a_i\}\not\equiv -1$ ということを考えて条件を簡略化すると

\lceil\log_wn\rceil + 2k \leq \frac{n}{2}
\therefore k \leq \frac{1}{2}\left(\frac{n}{2} - \lceil\log_wn\rceil\right)

ということになる。 また、$n$ は FFT アルゴリズムを適用させるために 2 のべき乗を選ぶという前提を合わせると、$n$ に依存して限界となる $N$ や $k$ も決まる。

  • $w=2$ の場合(単位:bit)
$n$ 32 64 128 256 512 1024
$k\leq$ 5 13 28 60 123 251
$N\leq$ 80 416 1792 7680 31488 128512
  • $w=2^{64}$ の場合(単位:word)
$n$ 32 64 128 256 512 1024
$k\leq$ 7 15 13 63 127 255
$N\leq$ 112 480 1984 8064 32512 130560
def SetUp(a, b):
  N = max(a.bit_length(), b.bit_length())
  for b in xrange(5, 30):
    n = 2 ** b
    k = (n / 2 - b) / 2
    if n * k >= N * 2:
      return n, k
  # 2^30 for n is too large
  assert false
15
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
9