筆者はレート800前後の茶~緑コーダ
ABC423のE問題を解いていく
実装コード
解説によると累積和を複数作って計算するのと
セグ木をつかって頑張るやつがあるらしく、後者を採用した(ほぼ写経だけど…)
やってることよくわかんないのでGPTに簡易解説してもらった
B[i].val は A の先頭からの累積和 P[i](ただし P[0]=0)。
各 B[i] の size=1。
seg.prod(l-1, r+1).res は、区間 i,j(l-1 ≤ i < j ≤ r)に対し P[j] - P[i] の総和になります。
これはちょうど 区間 [l, r] に含まれる全部分列の「部分列和」の総和 に等しいです。
別表現:sum_{k=l..r} A[k] * (k-l+1) * (r-k+1)。
もっと詳しく
このプログラムは、
入力された配列 A
と区間 [l, r]
に対して
区間
[l, r]
のすべての部分区間(連続部分列)について、その部分区間和を全部足し合わせた総和
を求めています。
つまり:
$$
\text{答え} = \sum_{i=l}^{r} \sum_{j=i}^{r} \left(\sum_{k=i}^{j} A_k\right)
$$
です。
🧩 具体例で理解する
たとえば
A = [1, 2, 3, 4]
l = 2, r = 3
とします。
-
区間
[2,3]
の部分区間は:[2] = 2
[3] = 3
[2,3] = 2+3=5
これらを全部足すと
→ 2 + 3 + 5 = 10
これが出力される答えです。
🧮 式の変形で見るともっとスッキリ
上の二重和(部分区間の総和)は、少し変形すると次のように書けます。
$$
\sum_{k=l}^{r} A_k \times (\text{kを含む部分区間の数})
$$
💭 kを含む部分区間の数とは?
ある要素 A[k]
が含まれる部分区間 [i, j]
とは:
-
i
はl
からk
まで選べる →(k-l+1)
通り -
j
はk
からr
まで選べる →(r-k+1)
通り
したがって:
$$
(\text{kを含む部分区間の数}) = (k-l+1) \times (r-k+1)
$$
🌟 よって、最終的に求めている値は:
$$
\boxed{
\text{答え} = \sum_{k=l}^{r} A_k \times (k-l+1) \times (r-k+1)
}
$$
⚙️ コードの中でどうやって計算しているか
この式を高速に計算するために、
セグメント木(segtree
)を使って次のような工夫をしています。
1️⃣ 構造体 S
の意味
class S:
def __init__(self, val=0, size=0, res=0):
self.val = val # 区間の合計値(累積和)
self.size = size # 区間の長さ(要素数)
self.res = res # 部分区間和の総和(答え)
2️⃣ op(l, r)
の意味
2つの区間 l
, r
を結合する時:
res.val = l.val + r.val
res.size = l.size + r.size
res.res = l.res + r.res + l.size*r.val - l.val*r.size
この式は、左右の区間を合わせた時の「部分区間和の合成則」です。
詳しく見ると:
-
l.res + r.res
→ 左右それぞれの区間内部の部分区間の総和。 -
l.size*r.val - l.val*r.size
→ 左右をまたぐ部分区間の寄与(左端は左区間から、右端は右区間から取る部分)。
これで全体の res.res
が一気に計算できます。
3️⃣ セグ木の目的
- 各点に
S
を持つことで、
区間[l, r]
に対してO(log N)
で合成(prod
)できる -
N
とQ
が大きくても(10⁵〜10⁶)でも高速
⚡ 動作まとめ
ステップ | 処理内容 |
---|---|
入力 |
N, Q, A を読み込む |
累積情報 B[i]
|
Aの先頭からi番目までの和を格納したS 構造体配列を作る |
セグ木構築 |
B をもとに区間情報を保持 |
各クエリ |
[l, r] に対して seg.prod(l-1, r+1).res を計算 |
出力 | 「部分区間和の総和」を出力 |
💡 まとめ(簡単な日本語で)
-
A
の中で[l, r]
の範囲を取る。 - その範囲内でできるすべての連続部分列の和を、ぜんぶ足す。
- 結果を
O(log N)
で高速に出している。
🔢 直感的なイメージ
例として A = [1, 2, 3]
のとき:
部分区間 | 和 |
---|---|
[1] | 1 |
[2] | 2 |
[3] | 3 |
[1,2] | 3 |
[2,3] | 5 |
[1,2,3] | 6 |
合計 | 20 |
→ 出力:20
セグ木の実装は以下を使用する
from bisect import bisect_left, bisect_right, insort_left, insort_right
from collections import defaultdict, Counter, deque
from copy import deepcopy
from functools import reduce, lru_cache
from itertools import product, accumulate, groupby, combinations
import sys
import os
def rI(): return int(sys.stdin.readline().rstrip())
def rLI(): return list(map(int,sys.stdin.readline().rstrip().split()))
def rI1(): return (int(sys.stdin.readline().rstrip())-1)
def rLI1(): return list(map(lambda a:int(a)-1,sys.stdin.readline().rstrip().split()))
def rS(): return sys.stdin.readline().rstrip()
def rLS(): return list(sys.stdin.readline().rstrip().split())
IS_LOCAL = int(os.getenv("ATCODER", "0"))==0
err = (lambda *args, **kwargs: print(*args, **kwargs, file=sys.stderr)) if IS_LOCAL else (lambda *args, **kwargs: None)
class segtree():
n=1
size=1
log=2
d=[0]
op=None
e=10**15
def __init__(self,V,OP,E):
self.n=len(V)
self.op=OP
self.e=E
self.log=(self.n-1).bit_length()
self.size=1<<self.log
self.d=[E for i in range(2*self.size)]
for i in range(self.n):
self.d[self.size+i]=V[i]
for i in range(self.size-1,0,-1):
self.update(i)
def set(self,p,x):
assert 0<=p and p<self.n
p+=self.size
self.d[p]=x
for i in range(1,self.log+1):
self.update(p>>i)
def get(self,p):
assert 0<=p and p<self.n
return self.d[p+self.size]
def prod(self,l,r):
assert 0<=l and l<=r and r<=self.n
sml=self.e
smr=self.e
l+=self.size
r+=self.size
while(l<r):
if (l&1):
sml=self.op(sml,self.d[l])
l+=1
if (r&1):
smr=self.op(self.d[r-1],smr)
r-=1
l>>=1
r>>=1
return self.op(sml,smr)
def all_prod(self):
return self.d[1]
def max_right(self,l,f):
assert 0<=l and l<=self.n
assert f(self.e)
if l==self.n:
return self.n
l+=self.size
sm=self.e
while(1):
while(l%2==0):
l>>=1
if not(f(self.op(sm,self.d[l]))):
while(l<self.size):
l=2*l
if f(self.op(sm,self.d[l])):
sm=self.op(sm,self.d[l])
l+=1
return l-self.size
sm=self.op(sm,self.d[l])
l+=1
if (l&-l)==l:
break
return self.n
def min_left(self,r,f):
assert 0<=r and r<=self.n
assert f(self.e)
if r==0:
return 0
r+=self.size
sm=self.e
while(1):
r-=1
while(r>1 and (r%2)):
r>>=1
if not(f(self.op(self.d[r],sm))):
while(r<self.size):
r=(2*r+1)
if f(self.op(self.d[r],sm)):
sm=self.op(self.d[r],sm)
r-=1
return r+1-self.size
sm=self.op(self.d[r],sm)
if (r& -r)==r:
break
return 0
def update(self,k):
self.d[k]=self.op(self.d[2*k],self.d[2*k+1])
def __str__(self):
return str([self.get(i) for i in range(self.n)])
class S:
def __init__(self,val=0,size=0,res=0):
self.val=val
self.size=size
self.res=res
def op(l,r):
res = S()
res.val=l.val+r.val
res.size=l.size+r.size
res.res=l.res+r.res
res.res+=l.size*r.val
res.res-=l.val*r.size
# err("-"*20)
# err(l.val,l.size,l.res)
# err(r.val,r.size,r.res)
# err(res.val,res.size,res.res)
return res
def main():
N,Q = rLI()
A = rLI()
B = [S() for _ in range(N+1)]
s = S(0,1,0)
B[0] = S(0,1,0)
for i in range(1,N+1):
s.val += A[i-1]
B[i] = deepcopy(s)
# err([b.val for b in B])
seg = segtree(B,op,S(0,0,0))
for _ in range(Q):
l, r = rLI()
ans = seg.prod(l-1,r+1).res
print(ans)
if __name__ == '__main__':
main()
感想
問題見た直後は累積和することまではわかったけどその後が困難を極めた。
解説見た後はそりゃ思いつかんわっていう解法だった。
やっぱりE問題はなかなか攻略できませんなー…
余談:実装するときにセグ木を久々に使ったので使い方ほぼ忘れてたのは内緒