LoginSignup
8
0

More than 1 year has passed since last update.

数論変換(NTT)による多項式の掛け算の実装 in PyPy

Last updated at Posted at 2022-06-01

こんにちは。株式会社オプティマインドの伊豆原と申します。

競プロ部の活動の一環として、今回は数論変換(Number Theoric Transform。以下NTT)に基づいた多項式の掛け算の、PyPyでの実装について書きたいと思います。

多項式の掛け算とNTTについて

競技プログラミングの文脈では、多項式の掛け算またはそれに類する畳込みの演算を高速に行う時に高速フーリエ変換(FFT=Fast Fourier Transform)を使うことが多いと思います。FFTに関してはNumpyにはnumpy.fft.fftという関数が実装されてますが、こちらは浮動小数を扱うため、大きい値を係数に持つ多項式の掛け算に使用するには誤差が怖いです。またコンテストによってはPyPyでNumpyを使えないことも多いため、PyPyを主力にしたい場合は自力で実装する必要があります。

そしてどうせ実装するならば、誤差の心配のない整数レベルでのFFTを実装したいです。そこで出てくるのが数論変換(NTT=Number Theoric Transform)です。こちらは適当な1の冪乗根を持つ有限体上の離散フーリエ変換になりますが、FFTと同様の方法で高速に計算することができます。ただしFFTで使う複素数体には任意の1の冪乗根が入っていますが、任意の係数体ではそうはいきません。そのため計算上、長さ$2^m$の配列を変換するには係数体が$1$の$2^m$乗根を持っている必要があるという制限が発生します。

ここでありがたいことに(?)、競技プログラミングで剰余によく出てくる素数$p=998244353 = 119\cdot 2^{23}+1$はプロス素数(proth prime)と呼ばれる$k\cdot 2^m + 1$($k$は奇数)の形をした素数でして、このときフェルマーの小定理から$\mathbb{Z} / p \mathbb{Z}$には1の$2^m=2^{23}$乗根が入っていることが示せます。よってこの係数体$\mathbb{Z}/p\mathbb{Z}$において$2^{23}=8388608$個までの配列なら高速にNTTすることができ、特に$(8388608-1)//2 = 4194303$次同士の多項式の掛け算に使用することが出来ます。これは競技プログラミングの範囲なら十分なサイズと言えるかと思います。


とゆうわけで、$998244353$を法とした多項式の掛け算の、NTTによる実装を色々と試してみたいと思います。ベンチマークにおける多項式の次数$N$としては、コンテストでよく見る$N=2\times10^5$と、$\mathcal{O}(N\log N)$が想定解の時にでてくる$N=1 \times 10^6$の2パターンを使います。計測に使用した環境はMacBook Pro(2.0 GHz クアッドコア Intel Core i5, 16 GB RAM)です。なお、今回実装するアルゴリズムではNTTの適用のため、事前に次数を2の累乗になるまで0-パディングしています。

NTT用のクラスの大枠は以下のような感じで、関数polymul_nttのntt部分とintt(逆数論変換)の実装を色々と試してみます。
※ パフォーマンスの問題で、剰余演算に使用する定数はグローバルに宣言して再代入しないようにします。これはjitに定数だと認識させるためです。

MOD = 998244353 # : 119*2**23+1
K,M,W = 119, 23, 31 # W : 2のM乗根

class NTT:
    def __init__(self):
        # ws[i] = 1の2^i乗根  (31**(2**23) = 1 mod 998244353)
        self.ws = [pow(W,2**i,MOD) for i in range(M,-1,-1)]
        # inverse of ws
        self.iws = [pow(w,MOD-2,MOD) for w in self.ws]
    def polymul_ntt(self,f,g):
        nf = len(f)
        ng = len(g)
        m = nf+ng-1
        n = 2**(m-1).bit_length()
        f = [x % MOD for x in f]+[0]*(n-nf) # 0-padding
        g = [x % MOD for x in g]+[0]*(n-ng) # 0-padding
        self.ntt(f) # 実装予定
        self.ntt(g)
        for i in range(n):
            f[i] = f[i]*g[i]%MOD
        self.intt(f) # 実装予定
        return f[:m]
    def ntt(self, A):
        ....
    def intt(self, A):
        ....

再帰版

NTTの構造的に一番自然なのは再帰で書くことかと思われます。もちろん往々にして再帰関数はパフォーマンス面に不利がある印象(PyPyだと特に)ですが、アルゴリズムが明確であることと、nttinttも係数を変えるだけでほぼ同じ様に書けるという利点があります。

    def _ntt(self, A, tws): # len(A) must be a power of 2
        if len(A) == 1:
            return A
        B0 = self._ntt(A[::2], tws) # 再帰計算
        B1 = self._ntt(A[1::2], tws)
        k = (len(A)-1).bit_length()
        res = [0]*len(A)
        r = 1<<(k-1)
        wi = 1
        w = tws[k]
        for i,(b0,b1) in enumerate(zip(B0,B1)): # バタフライ演算
            res[i] = (b0 + b1 * wi) % MOD
            res[i+r] = (b0 - b1 * wi) % MOD
            wi = (wi*w) % MOD
        return res
    def ntt(self, A):
        A[:] = self._ntt(A, self.ws)
    def intt(self, A):
        ni = pow(len(A), MOD-2, MOD)
        A[:] = [x*ni%MOD for x in  self._ntt(A, self.iws)] # 結果をinplaceに格納

このときの計測時間は以下のようになりました。

  • $N=2\times 10^5$ : 0.73619427142512s
  • $N=1 \times 10^6$ : 3.1136469834751552s

やはり遅いですね。再帰が遅い以前の話として、再帰のたびに発生しているリスト生成コストが重そうです。$N=10^6$の時は競技プログラミング想定で1回の計算でTLEするでしょう。
リスト生成コストなどは引数を工夫することで改善できそうですが、そこで頑張るよりは素直に非再帰にしたいです。

ビットリバース版

上記の再帰版をそのまま非再帰にするのに厄介なのがインデックスの調整です。$n=8$において再帰計算で各インデックスの変数がどの組合せで計算されているかを確認しますと以下のようになってまして、これは実は元のインデックスのビットリバースになっていることが分かります(1=0b001<->0b100=4や3=0b011<->0b110=6など)。なので入力の配列をビットリバースに沿って置換した後にバタフライ演算をしていけば、NTTが非再帰に計算できることが分かります。

image.png

実装ですが、まず添字のビットリバースの結果を格納した配列を準備し(入力サイズが変わっても、適当な個数ずつ飛ばすことで使い回せます)、
※ 実装にはこちらの記事を参考にさせて頂きました => LeetCode / Reverse Bits

    def __init__(self):
        ....
        self.maxK = 21 # 2**maxKが計算したい配列の長さを上回るぐらいの大きさ
        self.maxL = 2 ** self.maxK # 2**21 = 2097152 > 10**6
        def reverse_bits(n):
            n = (n >> 16) | (n << 16)
            n = ((n & 0xff00ff00) >> 8) | ((n & 0x00ff00ff) << 8)
            n = ((n & 0xf0f0f0f0) >> 4) | ((n & 0x0f0f0f0f) << 4)
            n = ((n & 0xcccccccc) >> 2) | ((n & 0x33333333) << 2)
            n = ((n & 0xaaaaaaaa) >> 1) | ((n & 0x55555555) << 1)
            return n
        self.rev = [reverseBits(i) >> (32-self.maxK) for i in range(self.maxL)]

そしてnttinttの冒頭で配列をビットリバースした状態に置換し、そのままinplaceにバタフライ演算を行います。

    def _ntt(self, A, tws):
        n = len(A)
        k = (n-1).bit_length()
        step = self.maxL // (2**k)
        for i,j in enumerate(self.rev[::step]):
            if i < j: # 各ビット反転のペアは2回現れるため、そのうち片方を実行する
                A[i],A[j] = A[j],A[i] # ビットリバースによる置換
        r = 1
        for w in tws[1:k+1]:
            for l in range(0,n,r*2):
                wi = 1
                for i in range(r): # <- range(l,l+r)より速い(?)
                    A[l+i],A[l+i+r] = (A[l+i]+A[l+i+r]*wi)%MOD, (A[l+i]-A[l+i+r]*wi)%MOD
                    wi = (wi*w) % MOD
            r <<= 1
    def ntt(self,A):
        self._ntt(A, self.ws)
    def intt(self,A):
        self._ntt(A, self.iws)
        ni = pow(len(A), MOD-2, MOD)
        for i in range(len(A)):
            A[i] = A[i]*ni%MOD

このときの計測時間は以下のようになりました。

  • $N=2\times 10^5$ : 0.179728750487493s
  • $N=1 \times 10^6$ : 0.80814964076244s

だいぶ実用的な速度になりました。$N=1\times10^6$でも1回の計算ならTLEせずに間に合いそうです。

2-バタフライ版

先ほどのビットリバース版ですが、バタフライ演算の部分をマイナーチェンジしnttinttを分離することとでビットリバース部分を不要にすることができます。なお、再帰版およびビットリバース版で使用していたバタフライ演算をColley-Tukeyバタフライと呼ぶようで、この版ではinttに使用します。そしてntt側に使用しているものはGentleman-Sadeバタフライと呼ぶようです。ここら辺の変換の背景についてはすみません、寡聞にして存じません:bow:
なお、この実装ではnttinttそれぞれを単独で使えなくなることに注意です。

    def ntt(self, A):
        if len(A) == 1: return
        n = len(A)
        k = n.bit_length()-1
        r = 1<<(k-1)
        for w in self.ws[k:0:-1]:
            for l in range(0,n,2*r):
                wi = 1
                for i in range(r): # Gentleman-Sade butterfly
                    A[l+i],A[l+i+r] = (A[l+i]+A[l+i+r])%MOD,(A[l+i]-A[l+i+r])*wi%MOD
                    wi = wi*w%MOD
            r = r//2
    def intt(self, A):
        if len(A) == 1: return
        n = len(A)
        k = (n-1).bit_length()
        r = 1
        for w in self.iws[1:k+1]:
            for l in range(0,n,2*r):
                wi = 1
                for i in range(r): # Colley-Tukey butterfly
                    A[l+i],A[l+i+r] = (A[l+i]+A[l+i+r]*wi)%MOD,(A[l+i]-A[l+i+r]*wi)%MOD
                    wi = wi*w%MOD
            r = r*2
        ni = pow(n, MOD-2, MOD)
        for i in range(n):
            A[i] = A[i]*ni%MOD

このときの計測時間は以下のようになりました。

  • $N=2\times 10^5$ : 0.13946093189997555s
  • $N=1 \times 10^6$ : 0.5823669472120855s

ビットリバース版からさらに速くなりました。ビットリバースを省いただけにしては些か速くなりすぎてる気もしますが、jit様のお気に召されたのでしょうか……

k-reduction版

今までのソースからも分かる通り、NTTの計算には剰余演算が大量に発生します。この剰余演算をそれに相当する高速な演算に置き換えれば一定量のスピードアップが見込めそうです。そこで下記の論文で紹介されているK-REDおよびK-RED-2xというものを使ってみます。

K-REDとK-RED-2xは法となる$p=k\cdot2^m+1$に対して次のように定義されます。定義そのものには剰余演算と割り算がありますが、2の冪乗によるものなのでビットマスクとシフト演算で実現できます。

# p = k * 2**m + 1
def k_red(C):
    return k*(C%2**m) - C//2**m
def k_red_2x(C):
    return k**2*(C%2**m) - k*((C//2**m)%2**m) + C//2**(2*m)

k-redによる計算では剰余の値そのものではなく、剰余の値の$k$倍が計算されます(つまり$k\cdot c \equiv_p \text{k-red}(c)$。負値の時もあり)。k-red-2xでは剰余の値の$k^2$倍が計算されますので、元の配列や掛け算に出てくる1の冪乗根にうまいことその差分を吸収するような前処理が必要になります。

なお、厳密な算出はしておりませんが、k-red等の使用により剰余を取らなくても計算中の各$A[i]$の絶対値は$k\cdot p$程度に抑えられるようです。ここで残念な話ですが、$998244353=119\cdot 2^{23}+1$では$A[i]\cdot \omega ^j$の掛け算などが最大で$k \cdot p^2 = 119 \cdot 998244353^2 = 118582522807270244471 > 2^{63}-1$と64bitを超えるため、PyPyでは計算が極端に遅くなってしまいます。なので$A[i]\cdot \omega^j$の計算にk_red_2xを使用したいなら、より小さい素数(例えば$167772161=5\cdot 2^{25}+1$の元でNTTを行う必要があります。

以上のことを踏まえ、$p=167772161$で実装して計測したいと思います。特に論文の実装に準拠はせず、とりあえずk-redなどが使えるように実装しました。まずk-redに使用する定数や、k-redによる定数倍を吸収するための配列を準備し、

MOD = 167772161 # : 5*2**25+1
K,M,P = 5,25,17
PK2 = K**2%MOD
IK = pow(K,MOD-2,MOD)
IK2 = IK**2%MOD
M2 = M*2
MASK = 2**M-1

class NTT:
    def __init__(self):
        ...
        self.ws2 = [w*IK**2%MOD for w in self.ws]
        self.iws2 = [w*IK*IK%MOD for w in self.iws]

nttintt部分を以下のように実装してみます。

    def ntt(self, A):
        if len(A) == 1: return
        n = len(A)
        k = n.bit_length()-1
        r = 1<<(k-1)

        # k-redによるk倍を吸収する処理
        kik = pow(IK,k,MOD)
        for i in range(n):
            A[i] = A[i]*kik%MOD

        for w in self.ws2[k:0:-1]:
            for l in range(0,n,2*r):
                wi = IK # k-redとk-red-2xの差分を吸収する処理
                for i in range(r):
                    s,t,nwi = A[l+i]+A[l+i+r], (A[l+i]-A[l+i+r])*wi, wi*w
                    A[l+i] = K*(s&MASK)-(s>>M) # k-red
                    A[l+i+r] = PK2*(t&MASK)-K*((t>>M)&MASK)+(t>>M2) # k-red-2x
                    wi = PK2*(nwi&MASK)-K*((nwi>>M)&MASK)+(nwi>>M2) # k-red-2x
            r = r//2
    def intt(self, A):
        if len(A) == 1: return
        n = len(A)
        k = (n-1).bit_length()

        # k-red-2xによるk**2倍を吸収する処理
        kik = pow(IK2,k,MOD)
        for i in range(n):
            A[i] = A[i]*kik%MOD

        r = 1
        for w in self.iws2[1:k+1]:
            for l in range(0,n,2*r):
                wi = 1
                for i in range(r):
                    s,t,nwi = A[l+i]+A[l+i+r]*wi, A[l+i]-A[l+i+r]*wi, wi*w
                    A[l+i] = PK2*(s&MASK)-K*((s>>M)&MASK)+(s>>M2) # k-red-2x
                    A[l+i+r] = PK2*(t&MASK)-K*((t>>M)&MASK)+(t>>M2) # k-red-2x
                    wi = PK2*(nwi&MASK)-K*((nwi>>M)&MASK)+(nwi>>M2) # k-red-2x
            r = r*2
        ni = pow(n, MOD-2, MOD)
        for i in range(n):
            A[i] = A[i]*ni%MOD

手元で確認する感じ、バタフライ演算部に剰余計算が無くともちゃんとNTTが計算できてるようです。なんかそれだけで十分嬉しい。

このとき計測時間は以下のようになりました。

  • $N=2\times 10^5$ : 0.26006694413736114s
  • $N=1 \times 10^6$ : 1.1944619359746866s

残念ながら遅くなってしまいました。有効に活用しようとするならば、ちゃんと各演算の能率を考慮して実装する必要がありそうです。環境や使用言語によっては非常に有用そうなreductionなので気になる所ですが、今回の記事ではとりあえず紹介に留めさせて頂きます。

まとめ

以下が紹介した手法の計算速度をまとめた比較結果になります。

$2\times10^5$ $1\times 10^6$
recursive 0.73619427142512 3.1136469834751552
bit-reverse 0.179728750487493s 0.80814964076244s
2-butterfly 0.13946093189997555s 0.5823669472120855s
k-reduction 0.26006694413736114s 1.1944619359746866s

本記事で紹介したのは細かい手法の違いでしたが、他にも以下のような実装上の工夫をすればより速くなります。

  • Listのarray化(Python標準ライブラリ"array"を使用)
  • forループのunroll(バタフライ演算が配列上隣り合う要素の時に限って、など)
  • wiのキャッシュ化
  • JITオプションいじり(未検討)

ローカルで上記の実装上の工夫を試しますと$N=1\times10^6$で0.425s程度まで速くすることが出来ました。他にも様々な高速化の余地があると思われますので、また色々と試してみたいと思います。

結び

数論変換(NTT)に基づいた多項式の掛け算の、PyPyでの実装について書かせて頂きました。MOD=998244353をjitに定数として認識させないといけないなど、色々とPyPy特有と思われる落とし穴があるのが難しいですね。
読んで頂きありがとうございました。

8
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
0