はじめに
C++
の std::set
は素晴らしいライブラリです。自前でこれを実装しようとしてもなかなかここまで速いものは作れないでしょう。
std::set
の代替手段として tatyam さんのライブラリが有名ですが、これは std::set
がない Python
のためのやつなので、Python
ユーザー限定です。
Python
に順序付き集合がないことはよく話題に上がることなのですが、C++
の順序付き集合である std::set
も競プロをする上ではなかなか残念なので、C++
用の順序付き集合ライブラリを公開すればみなさん嬉しいかもと思い、公開するに至りました。
競技プログラミングにおける活用
std::set
にもつけ入る隙があります。
- ランダムアクセスが線形時間以上
- 集約や遅延評価が載っていない
- 値の重複を許さない(これは
std::map
と併用するなどで解決できますが)
これらを完全に補完したデータ構造を作りましたので、よければどうぞ。ただし、ライセンスは守ってくださいね。
データ構造の概要は こちらの記事 (平衡二分木入門 : Splay Tree) で説明しています。
自分で言うのもなんですが、超絶便利です。C++
の std::set
,std::multiset
,std::map
を完全にカバーできております。
ぼくの自作の多重集合 MyMultiSet
は $2$ つの要素 (Key,Value)
をもち、(Key,Value)
の辞書順で管理しています (例外アリ、後述)。これによって Key
や index の範囲を絞って、その内の Key
や Value
の集約 (総和,Min,Max) を取得できます。
ソースコード
使い方は後述します。あとコメントの英語が拙いのは気にしないでください。雰囲気がわかればいいのです。
ライセンスは守ってくださいね (このライブラリの冒頭に以下を書くだけで良いです)。
- Copyright ©️ (c) NokonoKotlin (okoteiyu) 2024. Released under the MIT license(https://opensource.org/licenses/mit-license.php)
ライセンスに従わない利用を確認した場合、ライセンスに従わない利用をやめるよう指示するなどの処置を取る場合があります。
#include<iostream>
#include<cassert>
/*
Copyright ©️ (c) NokonoKotlin (okoteiyu) 2024.
Released under the MIT license(https://opensource.org/licenses/mit-license.php)
*/
template<class type_key , class type_value>
class MyMultiSet{
private:
struct SplayNode{
SplayNode *parent = nullptr;
SplayNode *left = nullptr;
SplayNode *right = nullptr;
type_key Key;
type_value Value;
type_key Sum_key;
type_value Min_val,Max_val,Sum_val;
int SubTreeSize = 1;
private:
bool copied_instance = false;
public:
SplayNode copy(){
assert(copied_instance == false);
SplayNode res = *this;
res.left = nullptr;
res.right = nullptr;
res.parent = nullptr;
res.copied_instance = true;
return res;
}
SplayNode(){}
SplayNode(type_key key_ , type_value val_){
Key = key_;
Value = val_;
update();
}
void rotate(){
if(this->parent->parent != nullptr){
if(this->parent == this->parent->parent->left)this->parent->parent->left = this;
else this->parent->parent->right = this;
}
this->parent->eval();
this->eval();
if(this->parent->left == this){
this->parent->left = this->right;
if(this->right != nullptr)this->right->parent = this->parent;
this->right = this->parent;
this->parent = this->right->parent;
this->right->parent = this;
this->right->update();
}else{
this->parent->right = this->left;
if(this->left != nullptr)this->left->parent = this->parent;
this->left = this->parent;
this->parent = this->left->parent;
this->left->parent = this;
this->left->update();
}
this->update();
return;
}
int state(){
if(this->parent == nullptr)return 0;
this->parent->eval();
if(this->parent->left == this)return 1;
else if(this->parent->right == this)return 2;
return 0;
}
void splay(){
while(this->parent != nullptr){
if(this->parent->state() == 0){
this->rotate();
break;
}
if( this->parent->state() == this->state() )this->parent->rotate();
else this->rotate();
this->rotate();
}
this->update();
return;
}
void update(){
assert(copied_instance == false);
this->eval();
this->SubTreeSize = 1;
this->Sum_key = this->Key;
this->Max_val = this->Sum_val = this->Min_val = this->Value;
if(this->left != nullptr){
this->left->eval();
this->SubTreeSize += this->left->SubTreeSize;
if(this->left->Min_val < this->Min_val)this->Min_val = this->left->Min_val;
if(this->left->Max_val > this->Max_val)this->Max_val = this->left->Max_val;
this->Sum_key += this->left->Sum_key;
this->Sum_val += this->left->Sum_val;
}
if(this->right != nullptr){
this->right->eval();
this->SubTreeSize += this->right->SubTreeSize;
if(this->right->Min_val < this->Min_val)this->Min_val = this->right->Min_val;
if(this->right->Max_val > this->Max_val)this->Max_val = this->right->Max_val;
this->Sum_key += this->right->Sum_key;
this->Sum_val += this->right->Sum_val;
}
return;
}
void eval(){
assert(copied_instance == false);
}
};
inline static constexpr bool CompareNode(SplayNode *a , SplayNode *b , bool paired_compare){
a->eval();
b->eval();
if(!paired_compare)return a->Key <= b->Key;
if(a->Key == b->Key)return bool(a->Value <= b-> Value);
return bool(a->Key < b->Key);
}
SplayNode *get_sub(int index , SplayNode *root){
if(root==nullptr)return root;
SplayNode *now = root;
while(true){
now->eval();
int left_size = 0;
if(now->left != nullptr)left_size = now->left->SubTreeSize;
if(index < left_size)now = now->left;
else if(index > left_size){
now = now->right;
index -= left_size+1;
}else break;
}
now->splay();
return now;
}
SplayNode *merge(SplayNode *leftRoot , SplayNode *rightRoot){
if(leftRoot!=nullptr)leftRoot->update();
if(rightRoot!=nullptr)rightRoot->update();
if(leftRoot == nullptr)return rightRoot;
if(rightRoot == nullptr)return leftRoot;
rightRoot = get_sub(0,rightRoot);
rightRoot->left = leftRoot;
leftRoot->parent = rightRoot;
rightRoot->update();
return rightRoot;
}
std::pair<SplayNode*,SplayNode*> split(int leftnum, SplayNode *root){
if(leftnum<=0)return std::make_pair(nullptr , root);
if(leftnum >= root->SubTreeSize)return std::make_pair(root, nullptr);
root = get_sub(leftnum , root);
SplayNode *leftRoot = root->left;
SplayNode *rightRoot = root;
if(rightRoot != nullptr)rightRoot->left = nullptr;
if(leftRoot != nullptr)leftRoot->parent = nullptr;
leftRoot->update();
rightRoot->update();
return std::make_pair(leftRoot,rightRoot);
}
std::pair<SplayNode*,int> bound_sub(SplayNode* Node , SplayNode *root , bool lower , bool paired_compare){
if(root == nullptr)return std::make_pair(root,0);
SplayNode *now = root;
Node->update();
bool satisfy = false;
while(true){
now->eval();
if(lower)satisfy = !CompareNode(Node,now,paired_compare);
else satisfy = CompareNode(now,Node,paired_compare);
if(satisfy == true && now->right != nullptr)now = now->right;
else if(satisfy == false && now->left != nullptr)now = now->left;
else break;
}
int res = 0;
if(satisfy)res = 1;
now->splay();
if(now->left != nullptr)res += now->left->SubTreeSize;
return std::make_pair(now ,res);
}
SplayNode *insert_sub(SplayNode *NODE , SplayNode *root , bool paired_compare){
NODE->update();
if(root == nullptr)return NODE;
root = bound_sub(NODE,root,true,paired_compare).first;
root->eval();
if(!CompareNode(NODE , root , paired_compare)){
if(root->right != nullptr)root->right->parent = NODE;
NODE->right = root->right;
root->right = nullptr;
NODE->left = root;
}else{
if(root->left != nullptr)root->left->parent = NODE;
NODE->left = root->left;
root->left = nullptr;
NODE->right = root;
}
root->parent = NODE;
root->update();
NODE->update();
return NODE;
}
protected:
SplayNode *m_Root = nullptr;
inline static SplayNode* const m_bluff_object = new SplayNode();
inline static SplayNode* const BluffObject(type_key k , type_value v){
m_bluff_object->Key = k;
m_bluff_object->Value = v;
return m_bluff_object;
}
bool _paired = true;
void release(){while(m_Root != nullptr)this->Delete(0);}
void init(){
m_Root = nullptr;
_paired = true;
}
const SplayNode* const begin(){
if(size() == 0)return nullptr;
m_Root = get_sub(0,m_Root);
return m_Root;
}
public:
MyMultiSet(){init();}
~MyMultiSet(){release();}
MyMultiSet(const MyMultiSet<type_key,type_value> & x) = delete ;
MyMultiSet& operator = ( const MyMultiSet<type_key,type_value> & x) = delete ;
MyMultiSet ( MyMultiSet<type_key,type_value>&& x){assert(0);}
MyMultiSet& operator = ( MyMultiSet<type_key,type_value>&& x){assert(0);}
void copy(MyMultiSet<type_key,type_value>& x){
if(this->begin() == x.begin())return;
release();
init();
for(int i=0;i<x.size();i++){
SplayNode t=x.get(i);
this->insert_pair(t.Key,t.Value);
}
this->_paired = x._paired;
}
int size(){
if(m_Root == nullptr)return 0;
return m_Root->SubTreeSize;
}
SplayNode get(int i){
assert(0 <= i && i < size());
m_Root = get_sub(i,m_Root);
return m_Root->copy();
}
SplayNode GetRange(int l , int r){
assert(0 <= l && l < r && r <= size());
std::pair<SplayNode*,SplayNode*> tmp = split(r,m_Root);
SplayNode* rightRoot = tmp.second;
tmp = split(l,tmp.first);
SplayNode res = tmp.second->copy();
m_Root = merge(merge(tmp.first,tmp.second),rightRoot);
return res;
}
void insert( type_key key_ ){
_paired = false;
m_Root = insert_sub(new SplayNode(key_,type_value(0)) ,m_Root , false);
return;
}
void insert_pair( type_key key_ , type_value val_){
assert(_paired);
m_Root = insert_sub(new SplayNode(key_,val_) ,m_Root,true);
return;
}
void Delete(int index){
assert(0 <= index && index < size());
SplayNode *center = get_sub(index,m_Root);
SplayNode *leftRoot = center->left;
SplayNode *rightRoot = center->right;
if(leftRoot != nullptr)leftRoot->parent = nullptr;
if(rightRoot != nullptr)rightRoot->parent = nullptr;
center->left = nullptr;
center->right = nullptr;
center->update();
m_Root = merge(leftRoot,rightRoot);
delete center;
}
void erase(type_key key_){
int it = find(key_);
if(it!=-1)Delete(it);
return;
}
void erase_pair(type_key key_ , type_value val_){
assert(_paired);
int it = find_pair(key_ , val_);
if(it!=-1)Delete(it);
return;
}
int lower_bound(type_key x){
if(size() == 0)return 0;
std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,type_value(0)),m_Root,true,false);
m_Root = tmp.first;
return tmp.second;
}
int lower_bound_pair(type_key x , type_value y){
assert(_paired);
if(size() == 0)return 0;
std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,y),m_Root,true,true);
m_Root = tmp.first;
return tmp.second;
}
int upper_bound(type_key x){
if(size() == 0)return 0;
std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,type_value(0)),m_Root,false,false);
m_Root = tmp.first;
return tmp.second;
}
int upper_bound_pair(type_key x , type_value y){
assert(_paired);
if(size() == 0)return 0;
std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,y),m_Root,false,true);
m_Root = tmp.first;
return tmp.second;
}
int count(type_key x){
return upper_bound(x) - lower_bound(x);
}
int count_pair(type_key x , type_key y){
assert(_paired);
return upper_bound_pair(x,y) - lower_bound_pair(x,y);
}
int find(type_key x){
if(size() == 0)return -1;
if(count(x) == 0)return -1;
return lower_bound(x);
}
int find_pair(type_key x , type_value y){
assert(_paired);
if(size() == 0)return -1;
if(count_pair(x,y) == 0)return -1;
return lower_bound_pair(x,y);
}
void Debug(){
std::cerr<<"DEBUG: -- Size = " << size() << std::endl;
for( int i = 0 ; i < size() ; i++)std::cerr<< get(i).Key << " ";
std::cerr<< std::endl;
if(_paired == false)return ;
for( int i = 0 ; i < size() ; i++)std::cerr<< get(i).Value << " ";
std::cerr<< std::endl;
}
SplayNode back(){assert(size()>0);return get(size()-1);}
SplayNode front(){assert(size()>0);return get(0);}
void pop_back(){assert(size()>0);Delete(size()-1);}
void pop_front(){assert(size()>0);Delete(0);}
SplayNode operator [](int index){return get(index);}
/*
Copyright ©️ (c) NokonoKotlin (okoteiyu) 2024.
Released under the MIT license(https://opensource.org/licenses/mit-license.php)
*/
};
簡単な使用例
void usage_set(){
/*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Key のみ使用する例 (使わなくてもテンプレートで Value の型は指示しておく)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*/
// 使わなくてもテンプレートで Value の型は指示しておく
MyMultiSet<int ,int> SingleSet;
// Key のみの insert を実行すると Value に関して未定義になる
SingleSet.insert(12);
SingleSet.insert(-5);
SingleSet.insert(-2);
SingleSet.insert(1);
SingleSet.insert(8);
SingleSet.insert(10);
SingleSet.insert(8);
SingleSet.insert(-2);
// 集合は {-5,-2,-2,1,8,8,10,12} となっている
// SingleSet.insert_pair(0,0);
// -> Value が未定義の時に Value に関するメソッドを呼ぶとエラー
// 1 番目 (0-index) の要素
std::cerr << SingleSet[1].Key << std::endl;// -2
// 8 未満の要素数
std::cerr << SingleSet.lower_bound(8) << std::endl;// 4
// 8 以下の要素数
std::cerr << SingleSet.upper_bound(8) << std::endl;// 6
// -2 以上 11 未満の要素の総和
std::cerr <<
SingleSet.GetRange(
SingleSet.lower_bound(-2),
SingleSet.lower_bound(11)
).Sum_key
<< std::endl;// 23
// -2 を一つだけ削除
SingleSet.erase(-2);
// -2 の個数を出力
std::cerr <<
SingleSet.upper_bound(-2) - SingleSet.lower_bound(-2)
<< std::endl;// 1
/*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Key , Value の両方を使用する例
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*/
MyMultiSet<int ,int> PairSet;
// ペアの挿入
PairSet.insert_pair(1,5);
PairSet.insert_pair(-2,4);
PairSet.insert_pair(-9,-2);
PairSet.insert_pair(-2,-9);
PairSet.insert_pair(3,1);
PairSet.insert_pair(1,-7);
PairSet.insert_pair(-2,10);
PairSet.insert_pair(8,6);
PairSet.insert_pair(7,5);
// 集合は (Key,Value) の辞書順
// { (-9,-2) , (-2,-9) , (-2,4) , (-2,10) , (1,-7) , (1,5) , (3,1) , (7,5) , (8,6) }
// 1 番目 (0-index) の要素
std::cerr << PairSet[1].Key << std::endl;// -2
std::cerr << PairSet[1].Value << std::endl;// -9
// Key が 1 未満の要素数
std::cerr << PairSet.lower_bound(1) << std::endl;// 4
// (Key,Value) が (1,5) 未満の要素数
std::cerr << PairSet.lower_bound_pair(1,5) << std::endl;// 5
// Key が 1 以下の要素数
std::cerr << PairSet.upper_bound(1) << std::endl;// 6
// (Key,Value) が (1,1) 以下の要素数
std::cerr << PairSet.upper_bound_pair(1,1) << std::endl;// 5
auto node_copy =
PairSet.GetRange(
PairSet.lower_bound(-2),
PairSet.upper_bound(-2));
// Key が -2 である要素の Value の Sum,Min,Max
std::cerr
<< node_copy.Sum_val << " , "
<< node_copy.Min_val << " , "
<< node_copy.Max_val
<< std::endl;// 5 , -9 , 10
// Key が 1 であるものを 1 つ削除
// 複数ある時、どれが削除されるかを強く保証するつもりはない
PairSet.erase(1);
// Value を指定して削除
PairSet.erase_pair(-2,5);
node_copy =
PairSet.GetRange(
PairSet.lower_bound(-2),
PairSet.lower_bound(7));
// Key が -2 以上 7 未満の要素の Value の Sum
std::cerr << node_copy.Sum_val << std::endl;// 11
}
使用例 1 ( ABC281-E )
実行時間 379ms ( TL : 2000ms )
挿入,削除,アクセス全てが $O(\log{N})$ 時間です。ある範囲の Key
の総和を取得して答えることもできます。
#include<iostream>
#include "MyMultiSet.hpp"
using std::cout ,std::endl , std::cin;
// ABC281-E (https://atcoder.jp/contests/abc281/tasks/abc281_e)
int main(){
int n , m , k;cin >> n >> m >> k;
long long A[200002];
MyMultiSet<long long,long long> S;
for(int i = 0 ; i < n ; i++)cin >> A[i];
for(int i = 0 ; i < m ; i++)S.insert(A[i]);
cout << S.GetRange(0,k).Sum_key << " ";
for(int i = 1 ; i < n-m+1 ; i++ ){
S.erase(A[i-1]);
S.insert(A[i+m-1]);
cout << S.GetRange(0,k).Sum_key << " ";
}
cout << endl;
return 0;
}
使用例 2 (Library Checker - Predecessor Problem)
実行時間 5355ms ( TL: 10000ms )
std::set
ほど速くはありませんが、
- $1\leq N\leq10^7$
- $1\leq Q \leq 10^6$
の制約でも $5$ 秒程度で終了します。
#include<iostream>
#include<string>
#include "MyMultiSet.hpp"
using std::cout , std::endl , std::cin , std::string;
//Library Checker - Predecessor Problem (https://judge.yosupo.jp/problem/predecessor_problem)
int main(){
int n,q;cin >> n >> q;
string t;cin >> t;
MyMultiSet<int,int> S;
for(int i = 0 ; i < t.size() ; i++){
if(t[i] != '0')S.insert(i);
}
while(q-->0){
int qt;cin >> qt;
int k;cin >> k;
if(qt == 0){
if(S.find(k) == -1)S.insert(k);
}else if(qt == 1)S.erase(k);
else if(qt == 2)cout << int(S.find(k) != -1) << endl;
else if(qt == 3){
int it = S.lower_bound(k);
if(it >= S.size())cout << -1 << endl;
else cout << S[it].Key << endl;
}else{
int it = S.upper_bound(k)-1;
if(it < 0) cout << -1 << endl;
else cout << S[it].Key << endl;
}
}
return 0;
}
使用例 3 ( ABC367-D )
実行時間 1102ms ( TL: 2000ms )
Value
に値を入れておけば、(Key,Value)
の辞書順で要素を調べることができる。Value
の集約も計算できる。詳細は後述。
#include<iostream>
#include<vector>
using std::cout,std::cin,std::endl;
using std::vector;
// ABC367-D(https://atcoder.jp/contests/abc367/tasks/abc367_d)
int main(){
int n , m;cin >> n >> m;
vector<long long> a(n) , r(n+1,0);
for(long long & x : a)cin >> x;
for(int i = 0 ; i < n ; i++)r[i+1] = r[i]+a[i];
MyMultiSet<long long , long long> S;
for(int i = 0 ; i <= n ; i++ )S.insert_pair(r[i]%m,-n + i);
long long accum = r[n];
long long ans = 0;
for(int i = 1 ; i <= n ; i++){
accum += a[i-1];
accum%=m;
ans += S.upper_bound(accum) - S.lower_bound_pair(accum,i-n+1);
S.insert_pair(accum,i);
}
cout << ans << endl;
return 0;
}
概要
type_key
型の Key
と type_value
型の Value
を持つ順序付き集合で、(Key , Value) が辞書順にソートされている。ただし、Value を無視する場合は辞書順ではなく、Key
の順序でソートされる。
- C++ の
std::set
とは異なり、Key
の重複を許す (Value
も当然重複 OK )。 -
get(i)
や[i]
でi
番目のノードのコピーを 0-index で取得。ただし隣接頂点へのアクセス (ポインタ) が封印されたものを返す。 -
Delete(i)
で小さい順でi
番目の要素を削除する。 -
GetRange(l,r)
は要素の辞書順の半開区間[l,r)
をカバーする部分木の根のコピーを返す。get()
同様に、隣接頂点のポインタは封印されている。-
GetRange(l,r).Sum_val
のようにして[l,r)
の持つ要素のモノイド積を取得する
-
(Key,Value) に関して以下の操作が可能
-
insert_pair(k,v)
-
(k,v)
を持つノードを追加
-
-
erase_pair(k,v)
-
(k,v)
を持つノードを(存在すれば)削除する。
-
-
upper_bound_pair(k,v)
-
(Key,Value)
が辞書順で(k,v)
以下の要素数を返す
-
-
lower_bound_pair(k,v)
-
(Key,Value)
が辞書順で(k,v)
未満の要素数を返す
-
-
find_pair(k,v)
-
(Key,Value)
が(k,v)
である要素の index を返す。存在しなければ-1
を返す (0-index)。
-
ノードの Key だけに注目して、通常の set のように振る舞わせることもできる。
-
insert(k)
-
Key = k
である要素を追加する。ただし、Value
を指定しないのでValue
は未定義とする。
-
-
erase(k)
-
Key = k
である要素を一つ削除する。ただし、Value
について特に指定しないことに注意。
-
-
upper_bound(k)
-
Key
がk
以下の要素数を返す
-
-
lower_bound(k)
-
Key
がk
未満の要素数を返す
-
-
find(k)
-
Key
がk
である要素の index を返す。存在しなければ-1
を返す (0-index)。
-
ペアを持つノードとKeyしか持たないノードが混在するといけない
-
insert()
を呼び出した時点で、Value
が未定義の要素が存在することになる-
insert()
を呼び出した後はupper_bound_pair
,find_pair
,RangeValueMaxQuery
など、Value
に関する関数を呼び出すとランタイムエラーになるようにしました。(ただし、Key
だけに関係する関数は変わらず使用できる。
-
注意点
コピーを禁止しているので、vector
など STL
コンテナに乗せるのは非推奨ですが、競プロでの実用で問題になることは珍しいかも。ただし、コピーを回避して実装しないとダメ。
初期化値を与えて初期化はダメ
// これは NG !!!
vector<MyMultiSet<int,int>> SV(100 , MyMultiSet<int,int>());
デフォルトコンストラクタの使用を指示しよう
// これは OK !!!
vector<MyMultiSet<int,int>> SV(100);
その他
- CompareNode() を変更することで要素の並び順を変更できる。
- 必要があれば
eval()
に遅延評価を実装して良い。- 順序付きなので、評価した後に順序が崩れないように制約をつける必要がある。
高速化手段
-
SplayNode
のupdate
の、不要な集約の記述を消す (SubTreeSize
は絶対必要)。 - デストラクタを消す (競プロ文脈限定)。
- など