セグ木の概念だけではなく、実際に手を動かして仕組みを知りたいという人へ。
ここでは、区間和を求めるセグ木を作っていきたいと思います。配列は何でもいいですが、例えば {1,2,3,4,5,6,7}
という配列を題材にしていきましょう。
配列をカバーする最小の2冪数を求める
まず、$1,2,4,8,\cdots,2^n$ の中で、配列をカバーする最小のものを求める必要があります。サンプルの配列の大きさは $7$ なので、この場合は $8$ になります。
int main(){
vector<int> a = {1,2,3,4,5,6,7};
int n = a.size(); // 7
int x = 1;
while(x < a.size()){
x *= 2;
}
cout << x << endl; // 8;
}
その2倍の配列を作る
実際に必要なのはその $2$ 倍の大きさになります。
vector<int> seg(2*x);
次に、その右半分に初期値を詰めていきます。これは、ちょうど配列をちょうど x
だけシフトした位置に置きます。
for(int i=0;i<n;i++){
seg[i+x] = a[i];
}
この時点で、seg
は以下のような配列になります。
seg = 0 0 0 0 0 0 0 0 1 2 3 4 5 6 7 0
初期状態を構築する
次に、左半分を埋めていきます。元の配列は $8,9,10,11,12,13,14,15$ 番目のノードに入っています。$8,9$ の親は $4$、$10,11$ の親は $5$、といったように、自分のノードを $2$ で割ったものが親のノードになります。
for(int i=x-1;i>0;i--){
seg[i] = seg[2*i] + seg[2*i+1];
}
この時点で、seg
は以下のような配列になります。
seg = 0 28 10 18 3 7 11 7 1 2 3 4 5 6 7 0
これでセグ木の構築が完了しました。
クエリを求める
ここからが難しいです。例えば、$2$ 番目から $6$ 番目の値の合計を求めたいとします。サンプルでは $2+3+4+5+6 = 20$ です。$0$-indexed の半開区間で表すと [1,6)
となります。
実装的には、はじめに親ノード($1$)に対して情報を求めます。
- クエリが今のノードを完全に含む場合→そのノードの値を返す
- クエリが今のノードを完全に含まない場合→単位元(サンプルでは
0
)を返す - そうではない場合(中途半端に含む場合)→子ノード達に丸投げする
という手順を繰り返せば、いつかは「1.」か「2.」のどちらかの値が返ってきて、最初のクエリの値が求められます。なお、ある親ノード $i$ の子ノードは、$2 \cdot i$ 及び $2 \cdot i+1$ になります。
実装としては、以下の情報を持って再帰的な探索を行います。
- 今どのノードにいるか
- 今のノードの左端
- 今のノードの右端
これは、実装例を見た方がわかりやすいかもです。
int qL = 1;
int qR = 6;
auto dfs = [&](auto dfs,int now,int L,int R)->int{
if(qL<=L && R<=qR){
return seg[now];
}else if(R<=qL || qR<=L){
return 0;
}else{
int mid = (L+R)/2;
int left_child = dfs(dfs,2*now,L,mid);
int right_child = dfs(dfs,2*now+1,mid,R);
return left_child + right_child;
}
};
cout << dfs(dfs,1,0,x) << endl; // 20
図にすると、以下のようになります。
更新する
サンプルの $4$ 番目のノードを 8
に更新するとします。
この時、現在ノードだけでなく、上流のノードも更新する必要があります。先にも述べた通り、ある子ノード $i$ の親ノードは $\lfloor \frac{i}{2} \rfloor$ なので、現在ノードが $1$ になるまで再帰処理を繰り返せばいいです。
int update = 8;
int pos = 3+x;
seg[pos] = update;
pos /= 2;
while(pos>0){
seg[pos] = seg[2*pos] + seg[2*pos+1];
pos /= 2;
}
この時点で、seg
は以下のような配列になります。
seg = 0 32 14 18 3 11 11 7 1 2 3 8 5 6 7 0
これでセグ木の更新が完了しました。
図にすると、以下のようになります。
もう一度同じ範囲でクエリを求めると、$24$ になり、正しく更新されていることがわかります。
cout << dfs(dfs,1,0,x) << endl; // 24
今までのコードまとめ
int main(){
vector<int> a = {1,2,3,4,5,6,7};
int n = a.size();
int x = 1;
while(x < a.size()){
x *= 2;
}
cout << x << endl;
vector<int> seg(2*x,0);
for(int i=0;i<n;i++){
seg[i+x] = a[i];
}
for(int i=x-1;i>0;i--){
seg[i] = seg[2*i] + seg[2*i+1];
}
int qL = 1;
int qR = 6;
auto dfs = [&](auto dfs,int now,int L,int R){
if(qL<=L && R<=qR){
return seg[now];
}else if(R<=qL || qR<=L){
return 0;
}else{
int mid = (L+R)/2;
int left_child = dfs(dfs,2*now,L,mid);
int right_child = dfs(dfs,2*now+1,mid,R);
return left_child + right_child;
}
};
cout << dfs(dfs,1,0,x) << endl; // 20
int update = 8;
int pos = 3+x;
seg[pos] = update;
pos /= 2;
while(pos>0){
seg[pos] = seg[2*pos] + seg[2*pos+1];
pos /= 2;
}
cout << dfs(dfs,1,0,x) << endl; // 24
}
a
、qL
、qR
、update
、pos
の値を変えてみると、任意の配列に対して、区間和の取得や一点更新が可能なセグ木であることがわかると思います。
ACLとの対応関係
ACL のライブラリで定義している op
、e
とは何かについて一応説明。
op
は操作であり、区間和で言うと和になります。サンプルで実装している部分だと、
// 略
seg[i] = seg[2*i] + seg[2*i+1];
// 略
}else{
int mid = (L+R)/2;
int left_child = dfs(dfs,2*now,L,mid);
int right_child = dfs(dfs,2*now+1,mid,R);
return left_child + right_child;
}
// 略
seg[pos] = seg[2*pos] + seg[2*pos+1];
ここになり、ここが足し算になるので、区間和になると言えます。
e
は単位元であり、相手の値に何も影響を与えないような値です。足し算だと 0
になります。サンプルで実装している部分だと、
vector<int> seg(2*x,0);
// 略
}else if(R<=qL || qR<=L){
return 0;
// 略
ここになり、区間和においてはここが 0
になります。
これを ACL のセッティングで表すと、
int op(int a,int b){
return a+b;
}
int e(){
return 0;
}
となります。
参考(区間最小値)
同じサンプルに対して、今度は区間最小値を求めてみます。
int main(){
vector<int> a = {1,2,3,4,5,6,7};
int n = a.size();
int x = 1;
while(x < a.size()){
x *= 2;
}
cout << x << endl;
vector<int> seg(2*x,1e9);
for(int i=0;i<n;i++){
seg[i+x] = a[i];
}
for(int i=x-1;i>0;i--){
seg[i] = min(seg[2*i],seg[2*i+1]);
}
int qL = 2;
int qR = 5;
auto dfs = [&](auto dfs,int now,int L,int R)->int{
if(qL<=L && R<=qR){
return seg[now];
}else if(R<=qL || qR<=L){
return 1e9;
}else{
int mid = (L+R)/2;
int left_child = dfs(dfs,2*now,L,mid);
int right_child = dfs(dfs,2*now+1,mid,R);
return min(left_child,right_child);
}
};
cout << dfs(dfs,1,0,x) << endl; // 3
int update = 8;
int pos = 2+x;
seg[pos] = update;
pos /= 2;
while(pos>0){
seg[pos] = min(seg[2*pos],seg[2*pos+1]);
pos /= 2;
}
cout << dfs(dfs,1,0,x) << endl; // 4
}
コードのどこが変わったかを見比べると、セグ木のやっていることがわかりやすいかもです。
参考(セグ木上の二分探索(max_right))
あまり使わないかもしれませんが、いわゆる max_right
についても、一応図とコードを。
int main(){
vector<int> a = {1,2,3,4,5,6,7};
int n = a.size();
int x = 1;
while(x < a.size()){
x *= 2;
}
cout << x << endl;
vector<int> seg(2*x,0);
for(int i=0;i<n;i++){
seg[i+x] = a[i];
}
for(int i=x-1;i>0;i--){
seg[i] = seg[2*i] + seg[2*i+1];
}
int target = 15;
int start = 1;
int sum = 0;
auto dfs = [&](auto dfs,int now,int L,int R)->int{
if(start<=L && R<=n){
if(sum + seg[now] <= target){
sum += seg[now];
return -1;
}else{
while(now<x){
if(sum + seg[now*2] <= target){
sum += seg[now*2];
now = now*2+1;
}else{
now = now*2;
}
}
return now-x;
}
}else if(R<=start || n<=L){
return -1;
}else{
int mid = L+R;
mid /= 2;
int left_res = dfs(dfs,now*2,L,mid);
if(left_res != -1){
return left_res;
}
int right_res = dfs(dfs,now*2+1,mid,R);
return right_res;
}
};
int max_right = dfs(dfs,1,0,x);
if(max_right == -1){
cout << n << endl;
}else{
cout << max_right << endl; // 5
}
}
やっていること。
- 左から順番に、最適に分割したセグメントを、
target
を超えない範囲でどんどん右に足していく - もし
sum
がtarget
を超えるような場合、下流の子供達を見ていく。左側の子供を足してtarget
も超えなければ、それを足して右側にいく。そうでなければ、何も足さないで左側にいく - ノードが最下層にいったとき、それが
max_right
のインデックスになる
ただし、インデックスが一番右までいくとき、この探索は -1
を返すので、その時は個別に処理する必要があります。ここら辺、dfs
内でうまく処理する方法がみつかれば追記するかもです。