3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

数論変換(NTT)による多項式の掛け算の実装 in C++

Last updated at Posted at 2022-11-30

こんにちは。株式会社オプティマインドの伊豆原と申します。

当社の競プロ部の活動の一環として、今回は数論変換(Number Theoric Transform。以下NTT)に基づいた多項式の掛け算の、C++での実装について書きたいと思います。

というのも以前、PyPyによる実装に関する以下の記事を書いておりまして、その中でk-reductionと呼ばれる特定の素数による剰余演算を高速化する手法について触れました。残念ながらPyPyでは遅くなってしまったので、C++のようなコンパイル言語ならどうなるだろうか?という短い検証記事になります。

k-reduction

k-reductionについて改めて紹介します。

k-reductionとはプロス素数と呼ばれる$p = k\cdot 2^m+1$の形をした素数(ここで$k$は奇数)を法とした剰余計算を、ビット演算で高速に行う手法(※1)です。下記の論文で提案されました。

※1: 正確には、十分小さい正の整数$k\geq 3,l\geq 1$に対する$p = k \cdot 2 ^m \pm l$を法とした剰余計算で使えるようです。

登場する2つの関数K-REDとK-RED-2xは数式で書きますと以下のようになり、

$ \text{K-RED}(C) := k \cdot (C \bmod 2^m) - (C \ / \ 2^m)$
$ \text{K-RED-2x}(C) := k^2 \cdot (C \bmod 2^m) - k \cdot (C \ / \ 2^m \bmod 2^m) + (C \ / \ 2^{2m})$

それぞれ$\text{K-RED}(C) \equiv k\cdot C \bmod p$および$\text{K-RED-2x}(C) \equiv k^2\cdot C \bmod p$となります。剰余そのものでなく係数付きなのがミソですね。

C++でビット演算を使って書くと以下のようになります。

// p = k*2^m+1による剰余の場合
using ll = long long;
ll mask = (1 << m) - 1;
ll mask2 = (1 << (2 * m)) - 1;

ll k_red(ll c){
    return k * (c & mask) - (c >> m));
}

ll k_red_2x(ll c){
    return k * k * (c & mask) - k * ((c >> m) & mask) + (c / mask2); 
}

ここで注意なのは、$p=k\cdot 2^m + 1$の値によってはK-REDやK-RED-2xによるNTTの計算途中の結果が64bitに収まらない可能性があることです。実際、各関数の大きさを評価しますと、

$ |\text{K-RED}(C)| := k \cdot 2^m + (|C| \ / \ 2^m)$
$ |\text{K-RED-2x}(C)| := k^2 \cdot 2^m + k \cdot 2^m + (|C| \ / \ 2^{2m})$

となります。計算に出てくる配列の成分の絶対値の最大値を$D$とすると、後述の実装ではK-REDに$2D$、K-RED-2Xに$Dp$程度の整数が渡されますので、この$Dp$が64bitに収まること、および下記の不等式を満たす必要があります。
$ k \cdot 2^m + 2D \ / \ 2^m \leq D$
$ k^2 \cdot 2^m + k \cdot 2^m + Dp \ / \ 2^{2m} \leq D$
不等式を(雑に)解きますと$k(k+1)\cdot 2^m \cdot \dfrac{2^m}{2^m-k} \leq D$となり、まぁだいたい$D\approx (k+1)p$ぐらいなら問題ないという話になりますが、残念ながら競プロでよく見る$p=998244353$のときの$k=119,m=23$ですと$Dp=(k+1)p^2=23915802919113326616 > 2^{64}$となり64bitを超えてしまいます。

実際、後述のNTTの実装を$p=998244353$で使用すると正しく計算できないことが実験で確かめられます。よってもう少し小さい素数を使用する必要があり、$p=167772161$ですと$k=5,m=25$で$(k+1)p^2=168884988039659526 < 2^{63}$となるのでk-reductionが問題なく使えることが分かります。よって本記事では$p=167772161$による剰余でベンチマークを取ります。

ベンチマークにおける多項式の次数$N$としては前回記事と同様に$N=2\times10^5$と$N=1 \times 10^6$の2パターンを使います。計測に使用した環境も同じくMacBook Pro(2.0 GHz クアッドコア Intel Core i5, 16 GB RAM)です。また、C++のコンパイラおよびコンパイルオプションはg++ -O2 --std=c++17を使用します。

2-バタフライ版

PyPy版の記事でも紹介しました2-バタフライ版の実装をC++に置き換えたものになります。

#include <iostream>
#include <vector>
#include <algorithm>

using ll = long long;

template <typename T>
constexpr T powMod(T p, T n, T m) {
    T res = 1;
    while (n) {
        if (n & 1)
            res = (res * p) % m;
        p = (p * p) % m;
        n >>= 1;
    }
    return (res + m) % m;
}

template <typename T>
const int bitLength(T i) {
    int res = 0;
    while (i) {
        i >>= 1;
        res++;
    }
    return res;
}

// 計算に計算する定数
constexpr ll MOD = 167772161; // K * (2**M) + 1
constexpr ll K = 5;
constexpr ll M = 25;
constexpr ll Q = 17; // 1の(2**M)冪乗根

class NTT {
public:
    std::vector<ll> ws = std::vector<ll>(M + 1);
    std::vector<ll> iws = std::vector<ll>(M + 1);

    NTT() {
        for (ll i = 0; i < M + 1; i++) {
            ws[i] = powMod(Q, 1ll << (M - i), MOD);
            iws[i] = powMod(ws[i], MOD - 2, MOD);
        }
    };
    // k-reductionを使った関数でオーバーライド予定
    virtual void ntt(std::vector<ll>& A) {
        if (A.size() == 1)
            return;
        int n = A.size();
        int k = bitLength(n - 1);
        int r = 1 << (k - 1);
        for (int m = k; m > 0; m--) {
            for (int l = 0; l < n; l += 2 * r) {
                ll wi = 1;
                for (int i = 0; i < r; i++) {
                    ll temp = (A[l + i] + A[l + i + r]) % MOD;
                    A[l + i + r] = (A[l + i] - A[l + i + r]) * wi % MOD;
                    A[l + i] = temp;
                    wi = wi * ws[m] % MOD;
                }
            }
            r >>= 1;
        }
    }
    virtual void intt(std::vector<ll>& A) {
        if (A.size() == 1)
            return;
        ll n = A.size();
        int k = bitLength(n - 1);
        int r = 1;
        for (int m = 1; m < k + 1; m++) {
            for (int l = 0; l < n; l += 2 * r) {
                ll wi = 1;
                for (int i = 0; i < r; i++) {
                    ll temp = (A[l + i] + A[l + i + r] * wi) % MOD;
                    A[l + i + r] = (A[l + i] - A[l + i + r] * wi) % MOD;
                    A[l + i] = temp;
                    wi = wi * iws[m] % MOD;
                }
            }
            r <<= 1;
        }
        ll ni = powMod(n, MOD - 2, MOD);
        for (int i = 0; i < n; i++) {
            A[i] = A[i] * ni % MOD;
        }
    }
    void polymul(std::vector<ll>& f, std::vector<ll>& g) {
        int m = f.size() + g.size() - 1;
        int n = 1 << bitLength(m - 1);
        for (int i = 0; i < f.size(); i++) {
            f[i] %= MOD;
        }
        for (int i = 0; i < g.size(); i++) {
            g[i] %= MOD;
        }
        f.resize(n, 0);
        g.resize(n, 0);
        ntt(f);
        ntt(g);
        for (int i = 0; i < n; i++) {
            f[i] = f[i] * g[i] % MOD;
        }
        intt(f);
        for (int i = 0; i < n; i++) {
            f[i] = (f[i] + MOD) % MOD;
        }
    }
};

計測結果は以下になりました。PyPyよりも速いのはさすがですね。

  • $N=2\times 10^5$ : 0.0900144s
  • $N=1 \times 10^6$ : 0.395433s

k-reduction版

前回記事と同様に、とりあえずk-reductionが使えるように事前調整などを行っています。

// k-reductionで使用する追加定数
constexpr long long IK = powMod(K, MOD - 2, MOD);
constexpr long long IK2 = (IK * IK) % MOD;
constexpr long long K2 = (K * K) % MOD;
constexpr long long mask = (1 << M) - 1;

class NTTRed : public NTT {
public:

    std::vector<ll> wsik2 = std::vector<ll>(M + 1);
    std::vector<ll> iwsik2 = std::vector<ll>(M + 1);

    NTTRed() {
        for (ll i = 0; i < M + 1; i++) {
            ws[i] = powMod(Q, 1ll << (M - i), MOD);
            wsik2[i] = ws[i] * IK2 % MOD;
            iws[i] = powMod(ws[i], MOD - 2, MOD);
            iwsik2[i] = iws[i] * IK2 % MOD;
        }
    };
    const long long k_red(const long long c) {
        return K * (c & mask) - (c >> M);
    }
    const long long k_red_2x(const long long c) {
        return (K2 * (c & mask)) - (K * ((c >> M) & mask)) + (c >> (M * 2));
    }
    void ntt(std::vector<ll>& A) {
        if (A.size() == 1) return;
        int n = A.size();
        int k = bitLength(n - 1);
        int r = 1 << (k - 1);
        ll kik = powMod(IK, (long long)k, MOD);
        for (int i = 0; i < n; i++) {
            A[i] = A[i] * kik % MOD;
        }
        for (int m = k; m > 0; m--) {
            for (int l = 0; l < n; l += 2 * r) {
                ll wi = IK;
                for (int i = 0; i < r; i++) {
                    ll temp = k_red(A[l + i] + A[l + i + r]);
                    A[l + i + r] = k_red_2x((A[l + i] - A[l + i + r]) * wi);
                    A[l + i] = temp;
                    wi = k_red_2x(wi * wsik2[m]);
                }
            }
            r >>= 1;
        }
    }
    void intt(std::vector<ll>& A) {
        if (A.size() == 1) return;
        ll n = A.size();
        int k = bitLength(n - 1);
        int r = 1;
        ll kik = powMod(IK, (long long)k, MOD);
        for (int i = 0; i < n; i++) {
            A[i] = A[i] * kik % MOD;
        }
        for (int m = 1; m < k + 1; m++) {
            for (int l = 0; l < n; l += 2 * r) {
                ll wi = IK2;
                for (int i = 0; i < r; i++) {
                    ll temp = k_red_2x(A[l + i + r] * wi);
                    A[l + i + r] = k_red(A[l + i] - temp);
                    A[l + i] = k_red(A[l + i] + temp);
                    wi = k_red_2x(wi * iwsik2[m]);
                }
            }
            r <<= 1;
        }
        ll ni = powMod(n, MOD - 2, MOD);
        for (int i = 0; i < n; i++) {
            A[i] = A[i] * ni % MOD;
        }
    }
};

計測結果は以下になりました。通常の剰余計算より17%程度速くなっています!

  • $N=2\times 10^5$ : 0.0742771s
  • $N=1 \times 10^6$ : 0.324050s

まとめ

PyPyでの結果も含めて表にまとめてみます。

$2\times10^5$ $1\times 10^6$
recursive(PyPy) 0.73619427142512 3.1136469834751552
bit-reverse(PyPy) 0.179728750487493s 0.80814964076244s
2-butterfly(PyPy) 0.13946093189997555s 0.5823669472120855s
k-reduction(PyPy) 0.26006694413736114s 1.1944619359746866s
2-butterfly(C++) 0.0900144s 0.395433s
k-reduction(C++) 0.0742771s 0.324050s

C++ではK-Reductionを使ってNTTが速くなることが確認できました。法が998244353では整数サイズ(64bit)の関係で使えないのが惜しいですね。

読んで頂きありがとうございました!

3
1
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?