筆者はレート800前後の茶~緑コーダ
ABC399Fの問題を解いてみる
コード
コーディングしてみたけどうまくオーダーを減らせなかったのでGPTにぶん投げた
プロンプト
2重ループを解消したい
from bisect import bisect_left, bisect_right, insort_left, insort_right
from collections import defaultdict, Counter, deque
from functools import reduce, lru_cache
from itertools import product, accumulate, groupby, combinations
import sys
import os
def rI(): return int(sys.stdin.readline().rstrip())
def rLI(): return list(map(int,sys.stdin.readline().rstrip().split()))
def rI1(): return (int(sys.stdin.readline().rstrip())-1)
def rLI1(): return list(map(lambda a:int(a)-1,sys.stdin.readline().rstrip().split()))
def rS(): return sys.stdin.readline().rstrip()
def rLS(): return list(sys.stdin.readline().rstrip().split())
IS_LOCAL = int(os.getenv("ATCODER", "0"))==0
err = (lambda *args, **kwargs: print(*args, **kwargs, file=sys.stderr)) if IS_LOCAL else (lambda *args, **kwargs: None)
class ModInt:
def __init__(self, x, mod = 998244353):
self.mod = mod
self.x = x.x if isinstance(x, ModInt) else x % self.mod
__str__ = lambda self:str(self.x)
__repr__ = __str__
__int__ = lambda self: self.x
__index__ = __int__
__add__ = lambda self, other: ModInt(self.x + ModInt(other).x)
__sub__ = lambda self, other: ModInt(self.x - ModInt(other).x)
__mul__ = lambda self, other: ModInt(self.x * ModInt(other).x)
__pow__ = lambda self, other: ModInt(pow(self.x, ModInt(other).x, self.mod))
__truediv__ = lambda self, other: ModInt(self.x * pow(ModInt(other).x, self.mod - 2, self.mod))
__floordiv__ = lambda self, other: ModInt(self.x // ModInt(other).x)
__lt__ = lambda self, other: self.x < ModInt(other).x
__gt__ = lambda self, other: self.x > ModInt(other).x
__le__ = lambda self, other: self.x <= ModInt(other).x
__ge__ = lambda self, other: self.x >= ModInt(other).x
__eq__ = lambda self, other: self.x == ModInt(other).x
__ne__ = lambda self, other: self.x!= ModInt(other).x
class mint(ModInt):
pass
M = 998244353
def main():
N, k = rLI()
A = rLI()
s = [0] + list(accumulate(A))
err(s)
ans = mint(0)
# ans = mint(sum(pow(a, k, M) for a in A))
for l in range(N + 1):
for r in range(l + 1, N + 1):
# err(s[r] - s[l])
ans += pow(s[r] - s[l], k, M)
print(ans)
if __name__ == '__main__':
main()
AIの回答
考え方
もともと求めたい値は
$$
\text{ans}=\sum_{0\le l < r\le N}(s[r]-s[l])^k
$$
となっています。
ここで、$s[i]$ は先頭からの累積和($s[0]=0$)です。
(1) 二項展開を利用する
$$
(s[r]-s[l])^k = \sum_{p=0}^{k} \binom{k}{p} (-1)^p (s[r])^{k-p} ,(s[l])^p.
$$
これにより
$$
\text{ans}=\sum_{r=1}^{N} \sum_{l=0}^{r-1} \sum_{p=0}^{k} \binom{k}{p} (-1)^p (s[r])^{k-p} ,(s[l])^p.
$$
順序を入れ替えれば、
各 $r$ について
$$
T_r = \sum_{p=0}^{k} \binom{k}{p} (-1)^p ,(s[r])^{k-p} \left(\sum_{l=0}^{r-1}(s[l])^p \right).
$$
となり、
「内側の $l$ に対する和」を、各べき$p$について累積的に保持することで、
2重ループの内側のループを O(1)(べきの数 k+1 だけの定数回の処理)に置き換えることができます。
(2) 実装のポイント
-
入力値から $s$ (累積和)をまず計算します。
※ 元の値が大きい場合でも、最終的に $\bmod\ M$ を取るので途中も mod M で扱っても問題ありません。 -
各 $r$ で、
・ $(s[r])^i$ ($i=0,\dots,k$)を順次計算してリスト P に保存します。
・ 内側で $p=0,\dots,k$ について、
term = $\binom{k}{p} \times (-1)^p \times P[k-p] \times \text{cum}[p]$
として足し合わせ、
$T_r$ を得ます。
・ その後、累積配列
$\text{cum}[p]$(これまでの s[l] の $p$ 乗の和:l=0~r)に対して、
$P[p]$ を加え更新します。 -
ここで cum[0] は、初期状態で s[0] の 0 乗=1 になるように初期化しておきます。
(その他、s[0] は 0 なので p≧1 では 0 です。)
これにより各 r 毎の計算は O(k) となり、全体では O(N·k) で済みます。
サンプルコード
from sys import stdin, stdout
from itertools import accumulate
M = 998244353
def main():
import sys
input = sys.stdin.readline
# 入力読み込み
N, k = map(int, input().split())
A = list(map(int, input().split()))
# 累積和 s[0]=0, s[i] = A[0]+...+A[i-1] (mod M)
S = [0]*(N+1)
for i in range(N):
S[i+1] = (S[i] + A[i]) % M
# kが0の場合も含めて、kの二項係数を前計算 (binom[p] = C(k, p), p=0..k)
fact = [1]*(k+1)
for i in range(1, k+1):
fact[i] = fact[i-1]*i % M
invfact = [1]*(k+1)
invfact[k] = pow(fact[k], M-2, M)
for i in range(k, 0, -1):
invfact[i-1] = invfact[i]*i % M
binom = [ (fact[k] * invfact[p] % M)*invfact[k-p] % M for p in range(k+1) ]
# 累積和の累積:cum[p] = sum_{l=0}^{r-1} (S[l])^p mod M
# 初期値: S[0]=0ならば、0^0 = 1, 0^p = 0 (p>=1)
cum = [0]*(k+1)
cum[0] = 1 # 0の0乗は1
ans = 0
# r = 1~Nに対して、s[r]とこれまでの S[l] (l < r) で値を足し合わせる
for r in range(1, N+1):
# S[r] の各累乗値を前計算 P[i] = (S[r])^i (i = 0, ..., k)
P = [1]*(k+1)
for i in range(1, k+1):
P[i] = (P[i-1] * S[r]) % M
term_r = 0
# p=0,...,k に対して内和をとる:
# (S[r]-S[l])^k = sum_{p=0}^{k} binom[k][p] * (-1)^p * (S[r])^(k-p) * (S[l])^p
# ここで、cum[p] = sum_{l=0}^{r-1} (S[l])^p
for p in range(0, k+1):
# (-1)^p mod M は、pが奇数なら M-1, 偶数なら 1
sign = 1 if p % 2 == 0 else (M - 1)
# P[k-p] = (S[r])^(k-p)
term = binom[p] * sign % M
term = term * P[k-p] % M
term = term * cum[p] % M
term_r = (term_r + term) % M
ans = (ans + term_r) % M
# 今回の S[r] に対して、累乗値 P[i] を累積配列に加算
for p in range(0, k+1):
cum[p] = (cum[p] + P[p]) % M
stdout.write(str(ans))
if __name__ == '__main__':
main()
解説
-
累積和の計算
元のリスト A から累積和 S を計算し、各 S[i] を mod M で扱います。
これにより、 ある区間の和は S[r] - S[l] として簡単に得られます。 -
二項展開の利用
$(s[r]-s[l])^k$ を
$$
\sum_{p=0}^{k}\binom{k}{p}(-1)^p, (s[r])^{,k-p}(s[l])^p
$$
と展開し、
内側の l に関する総和を各 p について累積しておけば、2重ループの内側は O(1)(正確には O(k))で済みます。 -
累積値の更新
各 r ごとに、これまでの s[l] の各べきの累積値を更新していきます。
この cum 配列を使うことで、l の内側の和の計算を即時にできるようにしています。
この手法は、特に k が小さい場合に非常に有効となります。
もし k が非常に大きい場合は、別の手法(たとえば高速多項式畳み込みなど)を検討する必要がありますが、たいていの問題では k は小さく設定されています。
感想
二項係数使う話は参考になった