search
LoginSignup
6
Help us understand the problem. What are the problem?

posted at

updated at

オンライン畳み込み

この記事の目的

オンライン畳み込み(Relaxed Convolution 1 または Relaxed Multiplication 2 などとも呼ばれるようです)を $O(N(\log N)^2)$ で処理する方法について書きます 3

畳み込みがオンラインとは、 $A$ および $B$ の各項が前から順に与えられたとき、その都度、畳み込み $C=A*B$ の各項を順に返すことを言います 4。詳細は後述します。

(復習)畳み込み

まずは通常の(オフラインの)畳み込みの復習をします。
数列 $(c_0,\ c_1,\cdots)$ が $(a_0,\ a_1,\cdots)$ と $(b_0,\ b_1,\cdots)$ の畳み込みであるとは、各 $i$ について

$$
c_i = \displaystyle\sum_{j=0}^{i} a_j \times b_{i-j}
$$

が成立することでした。あるいは、べき級数の言葉では

$A(x) = a_0 + a_1 x + a_2 x^2 + \cdots$
$B(x) = b_0 + b_1 x + b_2 x^2 + \cdots$
$C(x) = c_0 + c_1 x + c_2 x^2 + \cdots$

について、 $C(x)=A(x) \times B(x)$ が成立するとも言えます 5 6

$A(x)$ および $B(x)$ の最初の $N$ 項が与えられたとき、 $C(x)$ の最初の $N$ 項を求めることは FFT により $O(N\log N)$ でできます。

オンライン畳み込み

やりたいこと

冒頭で書いた通り、やりたいことは下記です。

$C(x)=A(x) \cdot B(x)$ とします。$q=0,\ 1,\ 2, \cdots,\ N-1$ の順に下記を処理してください。

  • $a_q$ および $b_q$ が与えられるので $c_q$ を返してください。

$c_i = \sum_{j=0}^{i} a_j \times b_{i-j}$ に従って愚直に計算すると $q$ 番目のクエリの計算量が $\Theta(q)$ 、合計で $\Theta(N^2)$ になってしまうので高速で処理するには工夫が必要です。

工夫

$A$ と $B$ の項をいくつかまとめて畳み込みをすると FFT の恩恵が受けられます。
具体的には次の表のようにまとめて計算します 1 7。表の見方は、上から $i$ 番目、左から $j$ 番目のセルを $q_{i,j}$ とすると、$A(x)$ の $i$ 次の項と $B(x)$ の $j$ 次の項の積を $q_{i,j}$ 番目のクエリの際に計算することを示しています。

image.png

次の 2 点が成立することからうまくいくことが分かります。

  • $q_{i,j} \ge \max(i, j)$ が成立する、すなわち畳み込みをする際には必要な係数の情報が揃っている
  • $q_{i,j} \le i+j$ が成立する、すなわち各クエリの処理後には必要な計算が終わっている

例えば、表中で $6$ が出てくるのは次の $5$ つの正方形部分です。
image.png

つまり $q=6$ 番目のクエリを処理する際には、

  • $a_0 \times b_6 x^6$
  • $(a_1 x+a_2 x^2)\times (b_5 x^5+b_6 x^6)$
  • $(a_3 x^3 +a_4 x^4 +a_5 x^5+a_6 x^6)\times (b_3x^3+b_4x^4+b_5x^5+b_6x^6)$
  • $(a_5x^5+a_6x^6)\times (b_1x+b_2x^2)$
  • $a_6x^6 \times b_0$

に対応する部分の畳み込みをすればよいです。

計算量

簡単のため $N=2^K-1$ とすると、一辺が $2^k$ の正方形は $2^{K-k+1}-1$ 個あるので、 FFT を使うとこの部分は $O(2^K \cdot k) \in O(N\log N)$ で処理できます。全体の計算量は $O(N (\log N)^2)$ です。

実装

各クエリでどの大きさの正方形まで処理するかが分かれば良いです。これは $q+2$ の lsb 8 を見ると分かります。ただし $q=2^k-2$ 型のときは最後の大きさの正方形が 1 つだけになり別処理が必要です。

コード

Python によるコードの例を示します。

test.py
P = 998244353
p, g, ig = 998244353, 3, 332748118
W = [pow(g, (p - 1) >> i, p) for i in range(24)]
iW = [pow(ig, (p - 1) >> i, p) for i in range(24)]
 
def convolve(a, b):
    def fft(f):
        for l in range(k, 0, -1):
            d = 1 << l - 1
            U = [1]
            for i in range(d):
                U.append(U[-1] * W[l] % p)
 
            for i in range(1 << k - l):
                for j in range(d):
                    s = i * 2 * d + j
                    t = s + d
                    f[s], f[t] = (f[s] + f[t]) % p, U[j] * (f[s] - f[t]) % p
 
    def ifft(f):
        for l in range(1, k + 1):
            d = 1 << l - 1
            U = [1]
            for i in range(d):
                U.append(U[-1] * iW[l] % p)
 
            for i in range(1 << k - l):
                for j in range(d):
                    s = i * 2 * d + j
                    t = s + d
                    f[s], f[t] = (f[s] + f[t] * U[j]) % p, (f[s] - f[t] * U[j]) % p
 
    n0 = len(a) + len(b) - 1
    if len(a) < 50 or len(b) < 50:
        ret = [0] * n0
        if len(a) > len(b): a, b = b, a
        for i, aa in enumerate(a):
            for j, bb in enumerate(b):
                ret[i+j] = (ret[i+j] + aa * bb) % p
        return ret
    
    k = (n0).bit_length()
    n = 1 << k
    a = a + [0] * (n - len(a))
    b = b + [0] * (n - len(b))
    fft(a), fft(b)
    for i in range(n):
        a[i] = a[i] * b[i] % p
    ifft(a)
    invn = pow(n, p - 2, p)
    for i in range(n0):
        a[i] = a[i] * invn % p
    del a[n0:]
    return a
 
class RelaxedMultiplication():
    # h = f * g
    def __init__(self):
        self.f = []
        self.g = []
        self.h = []
        self.n = 0
    
    def calc(self, l1, r1, l2, r2):
        self.h += [0] * (r1 + r2 - 1 - len(self.h))
        for i, a in enumerate(convolve(self.f[l1:r1], self.g[l2:r2]), l1 + l2):
            self.h[i] = (self.h[i] + a) % p
        
    def append(self, a, b):
        self.f.append(a)
        self.g.append(b)
        self.n += 1
        n = self.n
        m = (n + 1) & -(n + 1)
        s = 0
        if m <= n:
            a = 1
            while a <= m:
                self.calc(n - a, n, s, s + a)
                self.calc(s, s + a, n - a, n)
                s += a
                a <<= 1
        else:
            a = 1
            while a < m >> 1:
                self.calc(n - a, n, s, s + a)
                self.calc(s, s + a, n - a, n)
                s += a
                a <<= 1
            self.calc(n - a, n, s, s + a)
        return self.h[n-1]

Semi-Relaxed

$A(x)$ と $B(x)$ のうち片方は最初から与えられているとき Semi-Relaxed と言ったりするようです 2
$B$ 側が固定の場合、先ほど出てきた条件のうち「 $q_{i,j} \ge \max(i, j)$ 」が「 $q_{i,j} \ge i$ 」に緩められます。

  • $q_{i,j} \ge i$ が成立する、すなわち畳み込みをする際には必要な係数の情報が揃っている
  • $q_{i,j} \le i+j$ が成立する、すなわち各クエリの処理後には必要な計算が終わっている

すると右側方向には「フライング」ができるのでより効率よく計算することができます。例えば下記のようにすると良いです 7
オーダーは変わらず $O(N(\log N)^2)$ です。

image.png

関連問題

ABC 230-H
AC コードnoshi91さんの解説 を参考にしています。)

ABC 213-H
TLE コード (定数倍さん・・・)
AC コード (Semi-Relaxed にするとぎりぎり通る)

  1. Relaxed Convolution (中国語読めないけど 2 番目の表を参考にしました。) 2

  2. New algorithms for relaxed multiplication 2

  3. 分割統治 FFT とも呼ばれるようですが、本稿では分割統治の考え方は使っていません。結果的には似たような処理をやっていることになりますが。

  4. 参考文献では Relaxed という用語が使われていますが、「オンライン」の方がイメージしやすいと思うので本稿では同じ意味で使っています。

  5. 右辺は多項式環上での積を表します。畳み込みを定義するだけなら環は何でも良いのですが、本稿では FFT ができることを仮定しています。競技プログラミングでは例えば FFT-Friendly な素数 $p=998244353$ に対して $\mathbb{Z}/p\mathbb{Z}$ 上で考えることが多いです。

  6. 本記事でも一部べき級数の言葉を使って説明します。

  7. 式で書きづらいので表で勘弁してください。 2

  8. Least Significant Bit

Register as a new user and use Qiita more conveniently

  1. You can follow users and tags
  2. you can stock useful information
  3. You can make editorial suggestions for articles
What you can do with signing up
6
Help us understand the problem. What are the problem?