0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ABC423Eを解いた【累積和とセグ木】

Last updated at Posted at 2025-10-04

筆者はレート800前後の茶~緑コーダ

ABC423のE問題を解いていく

実装コード

解説によると累積和を複数作って計算するのと
セグ木をつかって頑張るやつがあるらしく、後者を採用した(ほぼ写経だけど…)

やってることよくわかんないのでGPTに簡易解説してもらった

B[i].val は A の先頭からの累積和 P[i](ただし P[0]=0)。
各 B[i] の size=1。
seg.prod(l-1, r+1).res は、区間 i,j(l-1 ≤ i < j ≤ r)に対し P[j] - P[i] の総和になります。
これはちょうど 区間 [l, r] に含まれる全部分列の「部分列和」の総和 に等しいです。
別表現:sum_{k=l..r} A[k] * (k-l+1) * (r-k+1)。

もっと詳しく

このプログラムは、
入力された配列 A と区間 [l, r] に対して

区間 [l, r] のすべての部分区間(連続部分列)について、その部分区間和を全部足し合わせた総和

を求めています。

つまり:

$$
\text{答え} = \sum_{i=l}^{r} \sum_{j=i}^{r} \left(\sum_{k=i}^{j} A_k\right)
$$

です。


🧩 具体例で理解する

たとえば

A = [1, 2, 3, 4]
l = 2, r = 3

とします。

  • 区間 [2,3] の部分区間は:

    • [2] = 2
    • [3] = 3
    • [2,3] = 2+3=5

これらを全部足すと
2 + 3 + 5 = 10

これが出力される答えです。


🧮 式の変形で見るともっとスッキリ

上の二重和(部分区間の総和)は、少し変形すると次のように書けます。

$$
\sum_{k=l}^{r} A_k \times (\text{kを含む部分区間の数})
$$


💭 kを含む部分区間の数とは?

ある要素 A[k] が含まれる部分区間 [i, j] とは:

  • il から k まで選べる → (k-l+1) 通り
  • jk から r まで選べる → (r-k+1) 通り

したがって:
$$
(\text{kを含む部分区間の数}) = (k-l+1) \times (r-k+1)
$$


🌟 よって、最終的に求めている値は:

$$
\boxed{
\text{答え} = \sum_{k=l}^{r} A_k \times (k-l+1) \times (r-k+1)
}
$$


⚙️ コードの中でどうやって計算しているか

この式を高速に計算するために、
セグメント木(segtree)を使って次のような工夫をしています。

1️⃣ 構造体 S の意味

class S:
    def __init__(self, val=0, size=0, res=0):
        self.val = val   # 区間の合計値(累積和)
        self.size = size # 区間の長さ(要素数)
        self.res = res   # 部分区間和の総和(答え)

2️⃣ op(l, r) の意味

2つの区間 l, r を結合する時:

res.val  = l.val + r.val
res.size = l.size + r.size
res.res  = l.res + r.res + l.size*r.val - l.val*r.size

この式は、左右の区間を合わせた時の「部分区間和の合成則」です。
詳しく見ると:

  • l.res + r.res
    → 左右それぞれの区間内部の部分区間の総和。
  • l.size*r.val - l.val*r.size
    → 左右をまたぐ部分区間の寄与(左端は左区間から、右端は右区間から取る部分)。

これで全体の res.res が一気に計算できます。


3️⃣ セグ木の目的

  • 各点に S を持つことで、
    区間 [l, r] に対して O(log N) で合成(prod)できる
  • NQ が大きくても(10⁵〜10⁶)でも高速

⚡ 動作まとめ

ステップ 処理内容
入力 N, Q, A を読み込む
累積情報 B[i] Aの先頭からi番目までの和を格納したS構造体配列を作る
セグ木構築 Bをもとに区間情報を保持
各クエリ [l, r] に対して seg.prod(l-1, r+1).res を計算
出力 「部分区間和の総和」を出力

💡 まとめ(簡単な日本語で)

  • A の中で [l, r] の範囲を取る。
  • その範囲内でできるすべての連続部分列の和を、ぜんぶ足す。
  • 結果を O(log N) で高速に出している。

🔢 直感的なイメージ

例として A = [1, 2, 3] のとき:

部分区間
[1] 1
[2] 2
[3] 3
[1,2] 3
[2,3] 5
[1,2,3] 6
合計 20

→ 出力:20


セグ木の実装は以下を使用する

main.py
from bisect import bisect_left, bisect_right, insort_left, insort_right
from collections import defaultdict, Counter, deque
from copy import deepcopy
from functools import reduce, lru_cache
from itertools import product, accumulate, groupby, combinations
import sys
import os
def rI(): return int(sys.stdin.readline().rstrip())
def rLI(): return list(map(int,sys.stdin.readline().rstrip().split()))
def rI1(): return (int(sys.stdin.readline().rstrip())-1)
def rLI1(): return list(map(lambda a:int(a)-1,sys.stdin.readline().rstrip().split()))
def rS(): return sys.stdin.readline().rstrip()
def rLS(): return list(sys.stdin.readline().rstrip().split())
IS_LOCAL = int(os.getenv("ATCODER", "0"))==0
err = (lambda *args, **kwargs: print(*args, **kwargs, file=sys.stderr)) if IS_LOCAL else (lambda *args, **kwargs: None)



class segtree():
    n=1
    size=1
    log=2
    d=[0]
    op=None
    e=10**15
    def __init__(self,V,OP,E):
        self.n=len(V)
        self.op=OP
        self.e=E
        self.log=(self.n-1).bit_length()
        self.size=1<<self.log
        self.d=[E for i in range(2*self.size)]
        for i in range(self.n):
            self.d[self.size+i]=V[i]
        for i in range(self.size-1,0,-1):
            self.update(i)
    def set(self,p,x):
        assert 0<=p and p<self.n
        p+=self.size
        self.d[p]=x
        for i in range(1,self.log+1):
            self.update(p>>i)
    def get(self,p):
        assert 0<=p and p<self.n
        return self.d[p+self.size]
    def prod(self,l,r):
        assert 0<=l and l<=r and r<=self.n
        sml=self.e
        smr=self.e
        l+=self.size
        r+=self.size
        while(l<r):
            if (l&1):
                sml=self.op(sml,self.d[l])
                l+=1
            if (r&1):
                smr=self.op(self.d[r-1],smr)
                r-=1
            l>>=1
            r>>=1
        return self.op(sml,smr)
    def all_prod(self):
        return self.d[1]
    def max_right(self,l,f):
        assert 0<=l and l<=self.n
        assert f(self.e)
        if l==self.n:
            return self.n
        l+=self.size
        sm=self.e
        while(1):
            while(l%2==0):
                l>>=1
            if not(f(self.op(sm,self.d[l]))):
                while(l<self.size):
                    l=2*l
                    if f(self.op(sm,self.d[l])):
                        sm=self.op(sm,self.d[l])
                        l+=1
                return l-self.size
            sm=self.op(sm,self.d[l])
            l+=1
            if (l&-l)==l:
                break
        return self.n
    def min_left(self,r,f):
        assert 0<=r and r<=self.n
        assert f(self.e)
        if r==0:
            return 0
        r+=self.size
        sm=self.e
        while(1):
            r-=1
            while(r>1 and (r%2)):
                r>>=1
            if not(f(self.op(self.d[r],sm))):
                while(r<self.size):
                    r=(2*r+1)
                    if f(self.op(self.d[r],sm)):
                        sm=self.op(self.d[r],sm)
                        r-=1
                return r+1-self.size
            sm=self.op(self.d[r],sm)
            if (r& -r)==r:
                break
        return 0
    def update(self,k):
        self.d[k]=self.op(self.d[2*k],self.d[2*k+1])
    def __str__(self):
        return str([self.get(i) for i in range(self.n)])


class S:
    def __init__(self,val=0,size=0,res=0):
        self.val=val
        self.size=size
        self.res=res

def op(l,r):
    res = S()
    res.val=l.val+r.val
    res.size=l.size+r.size
    res.res=l.res+r.res
    res.res+=l.size*r.val
    res.res-=l.val*r.size
    # err("-"*20)
    # err(l.val,l.size,l.res)
    # err(r.val,r.size,r.res)
    # err(res.val,res.size,res.res)
    return res

def main():
    N,Q = rLI()
    A = rLI()
    B = [S() for _ in range(N+1)]
    s = S(0,1,0)
    B[0] = S(0,1,0)
    for i in range(1,N+1):
        s.val += A[i-1]
        B[i] = deepcopy(s)
    # err([b.val for b in B])
    seg = segtree(B,op,S(0,0,0))
    
    for _ in range(Q):  
        l, r = rLI()
        ans = seg.prod(l-1,r+1).res
        print(ans)
        
if __name__ == '__main__':
    main()

感想

問題見た直後は累積和することまではわかったけどその後が困難を極めた。
解説見た後はそりゃ思いつかんわっていう解法だった。
やっぱりE問題はなかなか攻略できませんなー…

余談:実装するときにセグ木を久々に使ったので使い方ほぼ忘れてたのは内緒

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?