はじめに
C++
の std::set
は素晴らしいライブラリです。自前でこれを実装しようとしてもなかなかここまで速いものは作れないでしょう。
std::set
の代替手段として tatyam さんのライブラリが有名ですが、これは std::set
がない Python
のためのやつなので、Python
ユーザー限定です。
Python
に順序付き集合がないことはよく話題に上がることなのですが、C++
の順序付き集合である std::set
も競プロをする上ではなかなか残念なので、C++
用の順序付き集合ライブラリを公開すればみなさん嬉しいかもと思い、公開するに至りました。
競技プログラミングにおける活用
std::set
にもつけ入る隙があります。
- 要素アクセスが線形時間
- 集約や遅延評価が載っていない
- 値の重複を許さない(これは
std::map
と併用するなどで解決できますが)
これらを完全に補完したデータ構造を作りましたので、よければどうぞ。ただし、ライセンスは守ってくださいね。
コードは こちらの GitHub リポジトリ でも公開しています。データ構造の概要は こちらの記事 (平衡二分木入門 : Splay Tree) で説明しています。
自分で言うのもなんですが、超絶便利です。C++
の std::set
,std::multiset
,std::map
を完全にカバーできております。
ぼくの自作の多重集合 MyMultiSet
は $2$ つの要素 (Key,Value)
をもち、(Key,Value)
の辞書順で管理しています (例外アリ、後述)。これによって Key
や index の範囲を絞って、その内の Key
や Value
の集約 (総和,Min,Max) を取得できます。
ソースコード
使い方は後述します。あとコメントの英語が拙いのは気にしないでください。雰囲気がわかればいいのです。
ライセンスは守ってくださいね (このライブラリの冒頭に以下を書くだけで良いです)。
- Copyright ©️ (c) NokonoKotlin (okoteiyu) 2024. Released under the MIT license(https://opensource.org/licenses/mit-license.php)
ライセンスに従わない利用を確認した場合、ライセンスに従わない利用をやめるよう指示するなどの処置を取る場合があります。
#include<iostream>
#include<cassert>
/*
このコメントは消さないでください。
Don't Remove this Comment !!
Copyright ©️ (c) NokonoKotlin (okoteiyu) 2024.
Released under the MIT license(https://opensource.org/licenses/mit-license.php)
*/
template<class type_key , class type_value>
class MyMultiSet{
private:
struct SplayNode{
SplayNode *parent = nullptr;// parent node
SplayNode *left = nullptr;// left child
SplayNode *right = nullptr;// left child
type_key Key;// sorted key
type_value Value;// value (sorted if key is same)
type_key Sum_key;// Sum of Key in Subtree
type_value Min_val,Max_val,Sum_val;// Min,Max,Sum of Value in Subtree
int SubTreeSize = 1;// Size of Subtree under this node
private:
bool copied_instance = false;
public:
SplayNode copy(){
assert(copied_instance == false);
SplayNode res = *this;
res.left = nullptr;
res.right = nullptr;
res.parent = nullptr;
res.copied_instance = true;
return res;
}
SplayNode(){}
SplayNode(type_key key_ , type_value val_){
Key = key_;
Value = val_;
update();
}
// rotate ( this node - parent )
void rotate(){
if(this->parent->parent){
if(this->parent == this->parent->parent->left)this->parent->parent->left = this;
else this->parent->parent->right = this;
}
this->parent->eval();
this->eval();
if(this->parent->left == this){
this->parent->left = this->right;
if(this->right)this->right->parent = this->parent;
this->right = this->parent;
this->parent = this->right->parent;
this->right->parent = this;
this->right->update();
}else{
this->parent->right = this->left;
if(this->left)this->left->parent = this->parent;
this->left = this->parent;
this->parent = this->left->parent;
this->left->parent = this;
this->left->update();
}
this->update();
return;
}
// direction of this parent (left or right)
int state(){
if(this->parent == nullptr)return 0;
this->eval();
if(this->parent->left == this)return 1;
else if(this->parent->right == this)return 2;
return 0;
}
// bottom-up splay
void splay(){
while(bool(this->parent)){
if(this->parent->state() == 0){
this->rotate();
break;
}
if( this->parent->state() == this->state() )this->parent->rotate();
else this->rotate();
this->rotate();
}
this->update();
return;
}
// update data member
void update(){
assert(copied_instance == false);
this->eval();
this->SubTreeSize = 1;
this->Sum_key = this->Key;
this->Max_val = this->Sum_val = this->Min_val = this->Value;
// add left child
if(bool(this->left)){
this->left->eval();
this->SubTreeSize += this->left->SubTreeSize;
if(this->left->Min_val < this->Min_val)this->Min_val = this->left->Min_val;
if(this->left->Max_val > this->Max_val)this->Max_val = this->left->Max_val;
this->Sum_key += this->left->Sum_key;
this->Sum_val += this->left->Sum_val;
}
// add right child
if(bool(this->right)){
this->right->eval();
this->SubTreeSize += this->right->SubTreeSize;
if(this->right->Min_val < this->Min_val)this->Min_val = this->right->Min_val;
if(this->right->Max_val > this->Max_val)this->Max_val = this->right->Max_val;
this->Sum_key += this->right->Sum_key;
this->Sum_val += this->right->Sum_val;
}
return;
}
// evaluate Lazy Evaluation
void eval(){
// if it's necessary , write here.
assert(copied_instance == false);
}
};
/*
1. order of node's Key if [paired_compare] is false.
2. lexicographical order of node's (Key ,Value) if [paired_compare] is true.
*/
constexpr bool CompareNode(SplayNode *a , SplayNode *b , bool paired_compare = false){
a->eval();
b->eval();
if(!paired_compare)return a->Key <= b->Key;
else{
if(a->Key < b->Key)return true;
else if(a->Key == b->Key){
if(a->Value <= b-> Value)return true;
else return false;
}else return false;
}
}
// get [index]th node pointer
SplayNode *get_sub(int index , SplayNode *root){
if(root==nullptr)return root;
SplayNode *now = root;
while(true){
now->eval();
int left_size = 0;
if(now->left != nullptr)left_size = now->left->SubTreeSize;
if(index < left_size)now = now->left;
else if(index > left_size){
now = now->right;
index -= left_size+1;
}else break;
}
now->splay();
return now;
}
// merge 2 SplayTrees
SplayNode *merge(SplayNode *leftRoot , SplayNode *rightRoot){
if(leftRoot!=nullptr)leftRoot->update();
if(rightRoot!=nullptr)rightRoot->update();
if(bool(leftRoot ) == false)return rightRoot;
if(bool(rightRoot) == false )return leftRoot;
rightRoot = get_sub(0,rightRoot);
rightRoot->left = leftRoot;
leftRoot->parent = rightRoot;
rightRoot->update();
return rightRoot;
}
// split SplayTree at [leftnum]
std::pair<SplayNode*,SplayNode*> split(int leftnum, SplayNode *root){
if(leftnum<=0)return std::make_pair(nullptr , root);
if(leftnum >= root->SubTreeSize)return std::make_pair(root, nullptr);
root = get_sub(leftnum , root);
SplayNode *leftRoot = root->left;
SplayNode *rightRoot = root;
if(bool(rightRoot))rightRoot->left = nullptr;
if(bool(leftRoot))leftRoot->parent = nullptr;
leftRoot->update();
rightRoot->update();
return std::make_pair(leftRoot,rightRoot);
}
// remove [index]th node
std::pair<SplayNode*,SplayNode*> Delete_sub(int index , SplayNode *root){
if(bool(root) == false)return std::make_pair(root,root);
root = get_sub(index,root);
SplayNode *leftRoot = root->left;
SplayNode *rightRoot = root->right;
if(bool(leftRoot))leftRoot->parent = nullptr;
if(bool(rightRoot))rightRoot->parent = nullptr;
root->left = nullptr;
root->right = nullptr;
root->update();
return std::make_pair(merge(leftRoot,rightRoot) , root );
}
/*
between 2 SplayNodes [A] and [B] , we define following order.
- if [paired_compare] is false,
- [A] [<] [B] represent a order of these Keys.
- [A] [==] [B] represent these Keys are same
- if [paired_compare] is true,
- [A] [<] [B] represent a lexicographical order of these (Key , Value).
- [A] [==] [B] represent these (Key , Value) are same
This function finds the border index [B] which satisfies following.
1. if [lower] is true, for any [i] smaller than [B] , {[i]th node} [<] {[Node] argument}
2. if [lower] is false, for any [i] smaller than [B] , {[i]th node} [<] {[Node] argument} or {[i]th node} [==] {[Node] argument}
*/
std::pair<SplayNode*,int> bound_sub(SplayNode* Node , SplayNode *root , bool lower , bool paired_compare ){
if(bool(root) == false)return std::make_pair(root,0);
SplayNode *now = root;
int res = 0;
Node->update();
while(true){
now->eval();
bool satisfy = CompareNode(now,Node,paired_compare); // upper_bound (now <= Node)
if(lower)satisfy = !CompareNode(Node,now,paired_compare); // lower_bound (now < Node)
if(satisfy){
if(bool(now->right))now = now->right;
else {
res++;
break;
}
}else{
if(bool(now->left))now = now->left;
else break;
}
}
now->splay();
if(bool(now->left))res += now->left->SubTreeSize;
return std::make_pair(now ,res);
}
// insert [NODE]argument into SplayTree (in which nodes are sorted)
SplayNode *insert_sub(SplayNode *NODE , SplayNode *root , bool paired_compare){
NODE->update();
if(bool(root) == false)return NODE;
// find the border index [x] ( [x]th node [<] [NODE] argument ]
root = bound_sub(NODE,root,true,paired_compare).first;
root->eval();
if(!CompareNode(NODE , root , paired_compare)){
if(root->right != nullptr)root->right->parent = NODE;
NODE->right = root->right;
root->right = nullptr;
NODE->left = root;
}else{
if(root->left != nullptr)root->left->parent = NODE;
NODE->left = root->left;
root->left = nullptr;
NODE->right = root;
}
root->parent = NODE;
root->update();
NODE->update();
return NODE;
}
protected:
// root node of this tree
SplayNode *m_Root = nullptr;
// bluff node object (used as temporary node)
SplayNode *m_bluff_object = nullptr;
SplayNode* BluffObject(type_key k , type_value v){
if(m_bluff_object == nullptr)m_bluff_object = new SplayNode(type_key(0),type_value(0));
m_bluff_object->Key = k;
m_bluff_object->Value = v;
return m_bluff_object;
}
// flag of whether node's Values are defined
// (Values might be undefined if we use insert() function)
bool _paired = true;
void release(){while(m_Root != nullptr)this->Delete(0);}
void init(){
m_Root = nullptr;
_paired = true;
}
// pointer of leftmost node
const SplayNode* const begin(){
if(size() == 0)return nullptr;
m_Root = get_sub(0,m_Root);
return m_Root;
}
public:
MyMultiSet(){init();}
~MyMultiSet(){release();}
// don't copy this object
MyMultiSet(const MyMultiSet<type_key,type_value> & x) = delete ;
MyMultiSet& operator = ( const MyMultiSet<type_key,type_value> & x) = delete ;
MyMultiSet ( MyMultiSet<type_key,type_value>&& x){assert(0);}
MyMultiSet& operator = ( MyMultiSet<type_key,type_value>&& x){assert(0);}
// this function makes whole new SplayTree object from [x] argument
void copy(MyMultiSet<type_key,type_value>& x){
if(this->begin() == x.begin())return;
release();
init();
for(int i=0;i<x.size();i++){
SplayNode t=x.get(i);
this->insert_pair(t.Key,t.Value);
}
this->_paired = x._paired;
}
// tree size
int size(){
if(m_Root == nullptr)return 0;
return m_Root->SubTreeSize;
}
// get copy object of [i]th node
SplayNode get(int i){
assert(0 <= i && i < size());
m_Root = get_sub(i,m_Root);
return m_Root->copy();
}
// get copy object node which covers interval [l,r)
// Ex. we can get Sum of Key in [l,r)
SplayNode GetRange(int l , int r){
assert(0 <= l && l < r && r <= size());
std::pair<SplayNode*,SplayNode*> tmp = split(r,m_Root);
SplayNode* rightRoot = tmp.second;
tmp = split(l,tmp.first);// 部分木を取り出す。
SplayNode res = tmp.second->copy();
m_Root = merge(merge(tmp.first,tmp.second),rightRoot);
return res;
}
// insert key_
void insert( type_key key_ ){
_paired = false;// undefined Value was added
m_Root = insert_sub(new SplayNode(key_,type_value(0)) ,m_Root , false);
return;
}
// insert (key_ , value_)
void insert_pair( type_key key_ , type_value val_){
assert(_paired);
m_Root = insert_sub(new SplayNode(key_,val_) ,m_Root,true);
return;
}
// delete [index]th element
void Delete(int index){
assert(0 <= index && index < size());
std::pair<SplayNode*,SplayNode*> tmp = Delete_sub(index,m_Root);
m_Root = tmp.first;
if(tmp.second != nullptr)delete tmp.second;
return;
}
// erase 1 element which has key_ as Key
void erase(type_key key_){
int it = find(key_);
if(it!=-1)Delete(it);
return;
}
// erase 1 element which has (key_,value_) as (Key,Value)
void erase_pair(type_key key_ , type_value val_){
assert(_paired);
int it = find_pair(key_ , val_);
if(it!=-1)Delete(it);
return;
}
// counts nodes which < (x)
int lower_bound(type_key x){
if(size() == 0)return 0;
std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,type_value(0)),m_Root,true,false);
m_Root = tmp.first;
return tmp.second;
}
// counts nodes which < (x,y)
int lower_bound_pair(type_key x , type_value y){
assert(_paired);
if(size() == 0)return 0;
std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,y),m_Root,true,true);
m_Root = tmp.first;
return tmp.second;
}
// counts nodes which <= (x)
int upper_bound(type_key x){
if(size() == 0)return 0;
std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,type_value(0)),m_Root,false,false);
m_Root = tmp.first;
return tmp.second;
}
// counts nodes which <= (x,y)
int upper_bound_pair(type_key x , type_value y){
assert(_paired);
if(size() == 0)return 0;
std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,y),m_Root,false,true);
m_Root = tmp.first;
return tmp.second;
}
// find the index [i] which [i]th node has x as Key (if some answer exist,return smallest)
// if no answer is found, return -1
int find(type_key x){
if(size() == 0)return -1;
if(upper_bound(x) - lower_bound(x) <= 0)return -1;
return lower_bound(x);
}
// find the index [i] which [i]th node has (x,y) as (Key,Value) (if some answer exist,return smallest)
// if no answer is found, return -1
int find_pair(type_key x , type_value y){
assert(_paired);
if(size() == 0)return -1;
if(upper_bound_pair(x,y) - lower_bound_pair(x,y) <= 0)return -1;
return lower_bound_pair(x,y);
}
SplayNode back(){assert(size()>0);return get(size()-1);}
SplayNode front(){assert(size()>0);return get(0);}
void pop_back(){assert(size()>0);Delete(size()-1);}
void pop_front(){assert(size()>0);Delete(0);}
SplayNode operator [](int index){return get(index);}
};
使用例 1 ( ABC281-E )
実行時間 379ms ( TL : 2000ms )
挿入,削除,アクセス全てが $O(\log{N})$ 時間です。ある範囲の Key
の総和を取得して答えることもできます。
#include<iostream>
#include "MyMultiSet.hpp"
using std::cout ,std::endl , std::cin;
// ABC281-E (https://atcoder.jp/contests/abc281/tasks/abc281_e)
int main(){
int n , m , k;cin >> n >> m >> k;
long long A[200002];
MyMultiSet<long long,long long> S;
for(int i = 0 ; i < n ; i++)cin >> A[i];
for(int i = 0 ; i < m ; i++)S.insert(A[i]);
cout << S.GetRange(0,k).Sum_key << " ";
for(int i = 1 ; i < n-m+1 ; i++ ){
S.erase(A[i-1]);
S.insert(A[i+m-1]);
cout << S.GetRange(0,k).Sum_key << " ";
}
cout << endl;
return 0;
}
使用例 2 (Library Checker - Predecessor Problem)
実行時間 5355ms ( TL: 10000ms )
std::set
ほど速くはありませんが、
- $1\leq N\leq10^7$
- $1\leq Q \leq 10^6$
の制約でも $5$ 秒程度で終了します。
#include<iostream>
#include<string>
#include "MyMultiSet.hpp"
//Library Checker - Predecessor Problem (https://judge.yosupo.jp/problem/predecessor_problem)
int main(){
int n,q;cin >> n >> q;
string t;cin >> t;
MyMultiSet<int,int> S;
for(int i = 0 ; i < t.size() ; i++){
if(t[i] != '0')S.insert(i);
}
while(q-->0){
int qt;cin >> qt;
int k;cin >> k;
if(qt == 0){
if(S.find(k) == -1)S.insert(k);
}else if(qt == 1)S.erase(k);
else if(qt == 2)cout << int(S.find(k) != -1) << endl;
else if(qt == 3){
int it = S.lower_bound(k);
if(it >= S.size())cout << -1 << endl;
else cout << S[it].Key << endl;
}else{
int it = S.upper_bound(k)-1;
if(it < 0) cout << -1 << endl;
else cout << S[it].Key << endl;
}
}
return 0;
}
使用例 3 ( ABC367-D )
実行時間 1102ms ( TL: 2000ms )
Value
に値を入れておけば、(Key,Value)
の辞書順で要素を調べることができる。Value
の集約も計算できる。詳細は後述。
#include<iostream>
#include<vector>
using std::cout,std::cin,std::endl;
using std::vector;
// ABC367-D(https://atcoder.jp/contests/abc367/tasks/abc367_d)
int main(){
int n , m;cin >> n >> m;
vector<long long> a(n) , r(n+1,0);
for(long long & x : a)cin >> x;
for(int i = 0 ; i < n ; i++)r[i+1] = r[i]+a[i];
MyMultiSet<long long , long long> S;
for(int i = 0 ; i <= n ; i++ )S.insert_pair(r[i]%m,-n + i);
long long accum = r[n];
long long ans = 0;
for(int i = 1 ; i <= n ; i++){
accum += a[i-1];
accum%=m;
ans += S.upper_bound(accum) - S.lower_bound_pair(accum,i-n+1);
S.insert_pair(accum,i);
}
cout << ans << endl;
return 0;
}
概要
type_key
型の Key
と type_value
型の Value
を持つ順序付き集合で、(Key , Value) が辞書順にソートされている。ただし、Value を無視する場合は辞書順ではなく、Key
の順序でソートされる。
- C++ の
std::set
とは異なり、Key
の重複を許す (Value
も当然重複 OK )。 -
get(i)
や[i]
でi
番目のノードのコピーを 0-index で取得。ただし隣接頂点へのアクセス (ポインタ) が封印されたものを返す。 -
Delete(i)
で小さい順でi
番目の要素を削除する。 -
GetRange(l,r)
は要素の辞書順の半開区間[l,r)
をカバーする部分木の根のコピーを返す。get()
同様に、隣接頂点のポインタは封印されている。-
GetRange(l,r).Sum_val
のようにして[l,r)
の持つ要素のモノイド積を取得する
-
(Key,Value) に関して以下の操作が可能
-
insert_pair(k,v)
-
(k,v)
を持つノードを追加
-
-
erase_pair(k,v)
-
(k,v)
を持つノードを(存在すれば)削除する。
-
-
upper_bound_pair(k,v)
-
(Key,Value)
が辞書順で(k,v)
以下の要素数を返す
-
-
lower_bound_pair(k,v)
-
(Key,Value)
が辞書順で(k,v)
未満の要素数を返す
-
-
find_pair(k,v)
-
(Key,Value)
が(k,v)
である要素の index を返す。存在しなければ-1
を返す (0-index)。
-
ノードの Key だけに注目して、通常の set のように振る舞わせることもできる。
-
insert(k)
-
Key = k
である要素を追加する。ただし、Value
を指定しないのでValue
は未定義とする。
-
-
erase(k)
-
Key = k
である要素を一つ削除する。ただし、Value
について特に指定しないことに注意。
-
-
upper_bound(k)
-
Key
がk
以下の要素数を返す
-
-
lower_bound(k)
-
Key
がk
未満の要素数を返す
-
-
find(k)
-
Key
がk
である要素の index を返す。存在しなければ-1
を返す (0-index)。
-
ペアを持つノードとKeyしか持たないノードが混在するといけない
-
insert()
を呼び出した時点で、Value
が未定義の要素が存在することになる-
insert()
を呼び出した後はupper_bound_pair
,find_pair
,RangeValueMaxQuery
など、Value
に関する関数を呼び出すとランタイムエラーになるようにしました。(ただし、Key
だけに関係する関数は変わらず使用できる。
-
注意点
コピーを禁止しているので、vector
など STL
コンテナに乗せるのは非推奨ですが、競プロでの実用で問題になることは珍しいかも。ただし、コピーを回避して実装しないとダメ。
初期化値を与えて初期化はダメ
// これは NG !!!
vector<MyMultiSet<int,int>> SV(100 , MyMultiSet<int,int>());
デフォルトコンストラクタの使用を指示しよう
// これは OK !!!
vector<MyMultiSet<int,int>> SV(100);
その他
- CompareNode() を変更することで要素の並び順を変更できる。
- 必要があれば
eval()
に遅延評価を実装して良い。- 順序付きなので、評価した後に順序が崩れないように制約をつける必要がある。
高速化手段
-
SplayNode
のupdate
の、不要な集約の記述を消す (SubTreeSize
は絶対必要)。 - デストラクタを消す (競プロ文脈限定)。
- など