1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

実装しながら学ぶセグ木(コードと図解付)

Last updated at Posted at 2024-12-16

 セグ木の概念だけではなく、実際に手を動かして仕組みを知りたいという人へ。

 ここでは、区間和を求めるセグ木を作っていきたいと思います。配列は何でもいいですが、例えば {1,2,3,4,5,6,7} という配列を題材にしていきましょう。

配列をカバーする最小の2冪数を求める

 まず、$1,2,4,8,\cdots,2^n$ の中で、配列をカバーする最小のものを求める必要があります。サンプルの配列の大きさは $7$ なので、この場合は $8$ になります。

int main(){
    vector<int> a = {1,2,3,4,5,6,7};
    int n = a.size(); // 7
    int x = 1;
    while(x < a.size()){
        x *= 2;
    }
    cout << x << endl; // 8;
}

その2倍の配列を作る

 実際に必要なのはその $2$ 倍の大きさになります。

    vector<int> seg(2*x);

 次に、その右半分に初期値を詰めていきます。これは、ちょうど配列をちょうど x だけシフトした位置に置きます。

    for(int i=0;i<n;i++){
        seg[i+x] = a[i];
    }

 この時点で、seg は以下のような配列になります。

seg = 0 0 0 0 0 0 0 0 1 2 3 4 5 6 7 0

初期状態を構築する

 次に、左半分を埋めていきます。元の配列は $8,9,10,11,12,13,14,15$ 番目のノードに入っています。$8,9$ の親は $4$、$10,11$ の親は $5$、といったように、自分のノードを $2$ で割ったものが親のノードになります。

    for(int i=x-1;i>0;i--){
        seg[i] = seg[2*i] + seg[2*i+1];
    }

 この時点で、seg は以下のような配列になります。

seg = 0 28 10 18 3 7 11 7 1 2 3 4 5 6 7 0

 これでセグ木の構築が完了しました。

 図で表すと、以下のようになります。
image.png

クエリを求める

 ここからが難しいです。例えば、$2$ 番目から $6$ 番目の値の合計を求めたいとします。サンプルでは $2+3+4+5+6 = 20$ です。$0$-indexed の半開区間で表すと [1,6) となります。

 実装的には、はじめに親ノード($1$)に対して情報を求めます。

  1. クエリが今のノードを完全に含む場合→そのノードの値を返す
  2. クエリが今のノードを完全に含まない場合→単位元(サンプルでは 0)を返す
  3. そうではない場合(中途半端に含む場合)→子ノード達に丸投げする

 という手順を繰り返せば、いつかは「1.」か「2.」のどちらかの値が返ってきて、最初のクエリの値が求められます。なお、ある親ノード $i$ の子ノードは、$2 \cdot i$ 及び $2 \cdot i+1$ になります。

 実装としては、以下の情報を持って再帰的な探索を行います。

  • 今どのノードにいるか
  • 今のノードの左端
  • 今のノードの右端

 これは、実装例を見た方がわかりやすいかもです。

    int qL = 1;
    int qR = 6;
    auto dfs = [&](auto dfs,int now,int L,int R)->int{
        if(qL<=L && R<=qR){
            return seg[now];
        }else if(R<=qL || qR<=L){
            return 0;
        }else{
            int mid = (L+R)/2;
            int left_child = dfs(dfs,2*now,L,mid);
            int right_child = dfs(dfs,2*now+1,mid,R);
            return left_child + right_child;
        }
    };
    cout << dfs(dfs,1,0,x) << endl; // 20

 図にすると、以下のようになります。

image.png

更新する

 サンプルの $4$ 番目のノードを 8 に更新するとします。

 この時、現在ノードだけでなく、上流のノードも更新する必要があります。先にも述べた通り、ある子ノード $i$ の親ノードは $\lfloor \frac{i}{2} \rfloor$ なので、現在ノードが $1$ になるまで再帰処理を繰り返せばいいです。

    int update = 8;
    int pos = 3+x;
    seg[pos] = update;
    pos /= 2;
    while(pos>0){
        seg[pos] = seg[2*pos] + seg[2*pos+1];
        pos /= 2;
    }

 この時点で、seg は以下のような配列になります。

seg = 0 32 14 18 3 11 11 7 1 2 3 8 5 6 7 0

 これでセグ木の更新が完了しました。

 図にすると、以下のようになります。

image.png

 もう一度同じ範囲でクエリを求めると、$24$ になり、正しく更新されていることがわかります。

    cout << dfs(dfs,1,0,x) << endl; // 24

今までのコードまとめ

int main(){
    vector<int> a = {1,2,3,4,5,6,7};
    int n = a.size();
    int x = 1;
    while(x < a.size()){
        x *= 2;
    }
    cout << x << endl;
    vector<int> seg(2*x,0);
    for(int i=0;i<n;i++){
        seg[i+x] = a[i];
    }
    for(int i=x-1;i>0;i--){
        seg[i] = seg[2*i] + seg[2*i+1];
    }
    int qL = 1;
    int qR = 6;
    auto dfs = [&](auto dfs,int now,int L,int R){
        if(qL<=L && R<=qR){
            return seg[now];
        }else if(R<=qL || qR<=L){
            return 0;
        }else{
            int mid = (L+R)/2;
            int left_child = dfs(dfs,2*now,L,mid);
            int right_child = dfs(dfs,2*now+1,mid,R);
            return left_child + right_child;
        }
    };
    cout << dfs(dfs,1,0,x) << endl; // 20
    int update = 8;
    int pos = 3+x;
    seg[pos] = update;
    pos /= 2;
    while(pos>0){
        seg[pos] = seg[2*pos] + seg[2*pos+1];
        pos /= 2;
    }
    cout << dfs(dfs,1,0,x) << endl; // 24
}

 aqLqRupdatepos の値を変えてみると、任意の配列に対して、区間和の取得や一点更新が可能なセグ木であることがわかると思います。

ACLとの対応関係

 ACL のライブラリで定義している ope とは何かについて一応説明。

 op は操作であり、区間和で言うと和になります。サンプルで実装している部分だと、

        // 略
        seg[i] = seg[2*i] + seg[2*i+1];
        // 略
        }else{
            int mid = (L+R)/2;
            int left_child = dfs(dfs,2*now,L,mid);
            int right_child = dfs(dfs,2*now+1,mid,R);
            return left_child + right_child;
        }
        // 略
        seg[pos] = seg[2*pos] + seg[2*pos+1];

 ここになり、ここが足し算になるので、区間和になると言えます。

 e は単位元であり、相手の値に何も影響を与えないような値です。足し算だと 0 になります。サンプルで実装している部分だと、

    vector<int> seg(2*x,0);
    // 略
        }else if(R<=qL || qR<=L){
            return 0;
    // 略

 ここになり、区間和においてはここが 0 になります。

 これを ACL のセッティングで表すと、

int op(int a,int b){
    return a+b;
}
int e(){
    return 0;
}

 となります。

参考(区間最小値)

 同じサンプルに対して、今度は区間最小値を求めてみます。

int main(){
    vector<int> a = {1,2,3,4,5,6,7};
    int n = a.size();
    int x = 1;
    while(x < a.size()){
        x *= 2;
    }
    cout << x << endl;
    vector<int> seg(2*x,1e9);
    for(int i=0;i<n;i++){
        seg[i+x] = a[i];
    }
    for(int i=x-1;i>0;i--){
        seg[i] = min(seg[2*i],seg[2*i+1]);
    }
    int qL = 2;
    int qR = 5;
    auto dfs = [&](auto dfs,int now,int L,int R)->int{
        if(qL<=L && R<=qR){
            return seg[now];
        }else if(R<=qL || qR<=L){
            return 1e9;
        }else{
            int mid = (L+R)/2;
            int left_child = dfs(dfs,2*now,L,mid);
            int right_child = dfs(dfs,2*now+1,mid,R);
            return min(left_child,right_child);
        }
    };
    cout << dfs(dfs,1,0,x) << endl; // 3
    int update = 8;
    int pos = 2+x;
    seg[pos] = update;
    pos /= 2;
    while(pos>0){
        seg[pos] = min(seg[2*pos],seg[2*pos+1]);
        pos /= 2;
    }
    cout << dfs(dfs,1,0,x) << endl; // 4
}

 コードのどこが変わったかを見比べると、セグ木のやっていることがわかりやすいかもです。

参考(セグ木上の二分探索(max_right))

 あまり使わないかもしれませんが、いわゆる max_right についても、一応図とコードを。

image.png

int main(){
    vector<int> a = {1,2,3,4,5,6,7};
    int n = a.size();
    int x = 1;
    while(x < a.size()){
        x *= 2;
    }
    cout << x << endl;
    vector<int> seg(2*x,0);
    for(int i=0;i<n;i++){
        seg[i+x] = a[i];
    }
    for(int i=x-1;i>0;i--){
        seg[i] = seg[2*i] + seg[2*i+1];
    }
    int target = 15;
    int start = 1;
    int sum = 0;
    auto dfs = [&](auto dfs,int now,int L,int R)->int{
        if(start<=L && R<=n){
            if(sum + seg[now] <= target){
                sum += seg[now];
                return -1;
            }else{
                while(now<x){
                    if(sum + seg[now*2] <= target){
                        sum += seg[now*2];
                        now = now*2+1;
                    }else{
                        now = now*2;
                    }
                }
                return now-x;
            }
        }else if(R<=start || n<=L){
            return -1;
        }else{
            int mid = L+R;
            mid /= 2;
            int left_res = dfs(dfs,now*2,L,mid);
            if(left_res != -1){
                return left_res;
            }
            int right_res = dfs(dfs,now*2+1,mid,R);
            return right_res;
        }
    };
    int max_right = dfs(dfs,1,0,x);
    if(max_right == -1){
        cout << n << endl;
    }else{
        cout << max_right << endl; // 5
    }
}

 やっていること。

  1. 左から順番に、最適に分割したセグメントを、target を超えない範囲でどんどん右に足していく
  2. もし sumtarget を超えるような場合、下流の子供達を見ていく。左側の子供を足して target も超えなければ、それを足して右側にいく。そうでなければ、何も足さないで左側にいく
  3. ノードが最下層にいったとき、それが max_right のインデックスになる

 ただし、インデックスが一番右までいくとき、この探索は -1 を返すので、その時は個別に処理する必要があります。ここら辺、dfs 内でうまく処理する方法がみつかれば追記するかもです。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?