セグメント木
データ構造の一種です。
完全二分木であり、区間に関する処理を扱うのに適しています。
計算量はO(logN)です。
データ構造ですが、問題によって複雑なカスタマイズを求めまれます。
- 8の要素数の区間を処理するために必要な配列
8(n_とします)以上の2^xの中で最小の値nを求めます。
8の要素数を処理するのに2^4 * 2 - 1の要素数が必要になります。
C++
void init(int n_){
n = 1;
while(n < n_) n *= 2;
for(int i=0; i < 2 * n - 1; i++) dat[i] = INF;
}
- 親から子へ
C++
dat[k*2+1] // 左辺
dat[k*2+2] // 右辺
- 子から親へ
C++
dat[(k-1)/2] // 親
- 値の更新
今回は例として区間内の最小値を更新する処理になっています。
最大値の場合はminをmaxにする。
初期化時の値を制約上の最小値などに変更してください。
C++
// name:
// update
// proc:
// 木の更新の処理
// param:
// k = 更新したい区間の要素のindex
// a = 更新する値
// return:
// なし
void update(int k, int a){
// 根の要素数を加算して葉のindexにする
k += n - 1;
// 葉の値を更新する
dat[k] = a;
// 葉から根の値を更新する
while(k > 0){
// 親のindexを求める
k = (k - 1) / 2;
// 親の要素を、子の要素の値で更新する
dat[k] = min(dat[k * 2 + 1], dat[k * 2 + 2]);
}
}
ここからはAOJの問題を通じて見ていきましょう。
DSL_2_A
Range Minimum Query (RMQ)
セグメントツリーの問題です。
a_0, a_1, a_2, a_3, ......, a_{N-1}
の時、
- a_x = vと値を更新
- a_l, ...., a_rの中で最小の値を出力
という問題です。
C++
#include <bits/stdc++.h>
using namespace std;
#define INF 2147483647
int n;
int dat[2*100000 * 3];
void init(int n_){
n = 1;
while(n < n_) n *= 2;
for(int i=0; i < 2 * n - 1; i++) dat[i] = INF;
}
void update(int k, int a){
k += n - 1;
dat[k] = a;
while(k > 0){
k = (k - 1) / 2;
dat[k] = min(dat[k * 2 + 1], dat[k * 2 + 2]);
}
}
// name:
// query
// proc:
// クエリの処理
// dfsを使用して、kから始まる頂点から探索していきます。
// param:
// a = 求めたい区間の左端
// b = 求めたい区間の右端
// k = 求めたい区間の親の頂点である要素
// l = 根の要素で最も小さいindex
// r = 根の要素で最も大きいindex + 1
// return:
// 区間内の要素の最も小さい値
int query(int a, int b, int k, int l, int r){
// 範囲外か判定
if (r <= a || b <= l) return INF;
// 区間内の最小の値を返す
if (a <= l && r <= b) return dat[k];
else{
// 左辺の子を探索
int vl = query(a, b, k * 2 + 1, l, (l + r) / 2);
// 右辺の子を探索
int vr = query(a, b, k * 2 + 2, (l + r) / 2, r);
// 左辺、右辺で小さい値を返す
return min(vl, vr);
}
}
int main(void){
int N, Q;
cin >> N >> Q;
init(N);
for(int i=0; i<Q; i++){
int t;
cin >> t;
if(t==0){
int x, v;
cin >> x >> v;
update(x, v);
}
if(t==1){
int l, r;
cin >> l >> r;
cout << query(l, r+1, 0, 0, n) << endl;
}
}
return 0;
}
DSL_2_B
Range Sum Query
区間和の問題です。
前回と異なる箇所は区間和の箇所です。
- 初期値
dat[i]を0で初期化します。
C++
for(int i=0; i < 2 * n - 1; i++) dat[i] = 0;
- 更新
代入でしたが、今回は+演算子でaの値を加算します。
C++
dat[k] += a;
最小値はmin関数でしたが、今回は+演算子で子の合計を代入します。
C++
dat[k] = dat[k * 2 + 1] + dat[k * 2 + 2];
- クエリ
return min(vl, vr);
がreturn vl + vr;
になっていますね。
C++
int query(int a, int b, int k, int l, int r){
if (r <= a || b <= l) return 0;
if (a <= l && r <= b) return dat[k];
else{
int vl = query(a, b, k * 2 + 1, l, (l + r) / 2);
int vr = query(a, b, k * 2 + 2, (l + r) / 2, r);
return vl + vr;
}
}
実装したコードです。
C++
#include <bits/stdc++.h>
using namespace std;
int n;
int dat[2*100000 * 3];
void init(int n_){
n = 1;
while(n < n_) n *= 2;
for(int i=0; i < 2 * n - 1; i++) dat[i] = 0;
}
void update(int k, int a){
k += n - 1;
dat[k] += a;
while(k > 0){
k = (k - 1) / 2;
dat[k] = dat[k * 2 + 1] + dat[k * 2 + 2];
}
}
int query(int a, int b, int k, int l, int r){
if (r <= a || b <= l) return 0;
if (a <= l && r <= b) return dat[k];
else{
int vl = query(a, b, k * 2 + 1, l, (l + r) / 2);
int vr = query(a, b, k * 2 + 2, (l + r) / 2, r);
return vl + vr;
}
}
int main(void){
// cout << __cplusplus << endl;
int N, Q;
cin >> N >> Q;
init(N);
for(int i=0; i<Q; i++){
int t;
cin >> t;
if(t==0){
int x, v;
cin >> x >> v;
x--;
update(x, v);
}
if(t==1){
int l, r;
cin >> l >> r;
l--;
cout << query(l, r, 0, 0, n) << endl;
}
}
return 0;
}
練習問題
B - Fenwick Tree
Range Sum Query
区間和の問題です。
C++
#include <bits/stdc++.h>
using namespace std;
long long n;
vector<long long> dat;
void init(long long n_){
n = 1;
while(n < n_) n *= 2;
dat.resize(2 * n, 0);
for(long long i=0; i < 2 * n - 1; i++) dat[i] = 0;
}
void update(long long k, long long a){
k += n - 1;
dat[k] += a;
while(k > 0){
k = (k - 1) / 2;
dat[k] = dat[k * 2 + 1] + dat[k * 2 + 2];
}
}
long long query(long long a, long long b, long long k, long long l, long long r){
if (r <= a || b <= l) return 0;
if (a <= l && r <= b) return dat[k];
else{
long long vl = query(a, b, k * 2 + 1, l, (l + r) / 2);
long long vr = query(a, b, k * 2 + 2, (l + r) / 2, r);
return vl + vr;
}
}
int main(void){
// cout << __cplusplus << endl;
int N, Q;
cin >> N >> Q;
init(N);
for(int i=0; i<N; i++){
long long a;
cin >> a;
update(i, a);
}
for(int i=0; i<Q; i++){
long long t;
cin >> t;
if(t==0){
long long x, v;
cin >> x >> v;
update(x, v);
}
if(t==1){
long long l, r;
cin >> l >> r;
cout << query(l, r, 0, 0, n) << endl;
}
}
return 0;
}
F - Range Xor Query
xorのセグメントツリーです。
C++
#include <bits/stdc++.h>
using namespace std;
int n;
int dat[5*100000 * 3];
void init(int n_){
n = 1;
while(n < n_) n *= 2;
for(int i=0; i < 2 * n - 1; i++) dat[i] = 0;
}
void update(int k, int a){
k += n - 1;
dat[k] ^= a;
while(k > 0){
k = (k - 1) / 2;
dat[k] = dat[k * 2 + 1] ^ dat[k * 2 + 2];
}
}
int query(int a, int b, int k, int l, int r){
if (r <= a || b <= l) return 0;
if (a <= l && r <= b) return dat[k];
else{
int vl = query(a, b, k * 2 + 1, l, (l + r) / 2);
int vr = query(a, b, k * 2 + 2, (l + r) / 2, r);
return vl ^ vr;
}
}
int main(void){
// cout << __cplusplus << endl;
int N, Q;
cin >> N >> Q;
init(N);
for(int i=0; i<N; i++){
int a;
cin >> a;
update(i, a);
}
for(int i=0; i<Q; i++){
int t;
cin >> t;
if(t==1){
int x, v;
cin >> x >> v;
x--;
update(x, v);
}
if(t==2){
int l, r;
cin >> l >> r;
l--;
cout << query(l, r, 0, 0, n) << endl;
}
}
return 0;
}