21
19

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

自作 Set ライブラリ提供 : C++ の std::set が残念な件

Last updated at Posted at 2024-08-21

はじめに

C++std::set は素晴らしいライブラリです。自前でこれを実装しようとしてもなかなかここまで速いものは作れないでしょう。

std::set の代替手段として tatyam さんのライブラリが有名ですが、これは std::set がない Python のためのやつなので、Python ユーザー限定です。

Python に順序付き集合がないことはよく話題に上がることなのですが、C++ の順序付き集合である std::set も競プロをする上ではなかなか残念なので、C++ 用の順序付き集合ライブラリを公開すればみなさん嬉しいかもと思い、公開するに至りました。

競技プログラミングにおける活用

std::set にもつけ入る隙があります。

  • 要素アクセスが線形時間
  • 集約や遅延評価が載っていない
  • 値の重複を許さない(これは std::map と併用するなどで解決できますが)

これらを完全に補完したデータ構造を作りましたので、よければどうぞ。ただし、ライセンスは守ってくださいね。

コードは こちらの GitHub リポジトリ でも公開しています。データ構造の概要は こちらの記事 (平衡二分木入門 : Splay Tree) で説明しています。

自分で言うのもなんですが、超絶便利です。C++std::set,std::multiset,std::map を完全にカバーできております。

ぼくの自作の多重集合 MyMultiSet は $2$ つの要素 (Key,Value) をもち、(Key,Value) の辞書順で管理しています (例外アリ、後述)。これによって Key や index の範囲を絞って、その内の KeyValue の集約 (総和,Min,Max) を取得できます。

ソースコード

使い方は後述します。あとコメントの英語が拙いのは気にしないでください。雰囲気がわかればいいのです。

ライセンスは守ってくださいね (このライブラリの冒頭に以下を書くだけで良いです)。

ライセンスに従わない利用を確認した場合、ライセンスに従わない利用をやめるよう指示するなどの処置を取る場合があります。

#include<iostream>
#include<cassert>

/*
    このコメントは消さないでください。
    Don't Remove this Comment !!

    Copyright ©️ (c) NokonoKotlin (okoteiyu) 2024.
    Released under the MIT license(https://opensource.org/licenses/mit-license.php)
*/
template<class type_key , class type_value>
class MyMultiSet{
    private:
    struct SplayNode{
        SplayNode *parent = nullptr;// parent node
        SplayNode *left = nullptr;// left child
        SplayNode *right = nullptr;// left child
 
        type_key Key;// sorted key
        type_value Value;// value (sorted if key is same)
 
        type_key Sum_key;// Sum of Key in Subtree
        type_value Min_val,Max_val,Sum_val;// Min,Max,Sum of Value in Subtree 

        int SubTreeSize = 1;// Size of Subtree under this node

        private:
        bool copied_instance = false;
        public:
        SplayNode copy(){
            assert(copied_instance == false);
            SplayNode res = *this;
            res.left = nullptr;
            res.right = nullptr;
            res.parent = nullptr;
            res.copied_instance = true;
            return res;
        }

        SplayNode(){}
        SplayNode(type_key key_ , type_value val_){
            Key = key_;
            Value = val_;
            update();
        }


        // rotate ( this node - parent ) 
        void rotate(){
            if(this->parent->parent){
                if(this->parent == this->parent->parent->left)this->parent->parent->left = this;
                else this->parent->parent->right = this; 
            }

            this->parent->eval();
            this->eval();
            
            if(this->parent->left == this){
                this->parent->left = this->right;
                if(this->right)this->right->parent = this->parent;
                this->right = this->parent;
                this->parent = this->right->parent;
                this->right->parent = this;
                this->right->update();
            }else{
                this->parent->right = this->left;
                if(this->left)this->left->parent = this->parent;
                this->left = this->parent;
                this->parent = this->left->parent;
                this->left->parent = this;
                this->left->update();
            }

            this->update();
            return;
        }

        // direction of this parent (left or right)
        int state(){
            if(this->parent == nullptr)return 0;
            this->eval();
            if(this->parent->left == this)return 1;
            else if(this->parent->right == this)return 2;
            return 0;
        }

        // bottom-up splay 
        void splay(){
            while(bool(this->parent)){
                if(this->parent->state() == 0){
                    this->rotate();
                    break;
                }
                if( this->parent->state() == this->state() )this->parent->rotate();
                else this->rotate();
                this->rotate();
            }
            this->update();
            return;
        }

        // update data member
        void update(){
            assert(copied_instance == false);

            this->eval();
            this->SubTreeSize = 1;
            this->Sum_key = this->Key;
            this->Max_val = this->Sum_val = this->Min_val = this->Value;
            
            // add left child
            if(bool(this->left)){
                this->left->eval();
                this->SubTreeSize += this->left->SubTreeSize;
                if(this->left->Min_val < this->Min_val)this->Min_val = this->left->Min_val;
                if(this->left->Max_val > this->Max_val)this->Max_val = this->left->Max_val;
                this->Sum_key += this->left->Sum_key;
                this->Sum_val += this->left->Sum_val;
            }

            // add right child
            if(bool(this->right)){
                this->right->eval();
                this->SubTreeSize += this->right->SubTreeSize;
                if(this->right->Min_val < this->Min_val)this->Min_val = this->right->Min_val;
                if(this->right->Max_val > this->Max_val)this->Max_val = this->right->Max_val;
                this->Sum_key += this->right->Sum_key;
                this->Sum_val += this->right->Sum_val;
            }
 
            return;
        }

        // evaluate Lazy Evaluation
        void eval(){
            // if it's necessary , write here.
            assert(copied_instance == false);
        }
    };


    /*
        1. order of node's Key if [paired_compare] is false.
        2. lexicographical order of node's (Key ,Value) if [paired_compare] is true.
    */
    constexpr bool CompareNode(SplayNode *a , SplayNode *b , bool paired_compare = false){
        a->eval();
        b->eval();
        if(!paired_compare)return a->Key <= b->Key;
        else{
            if(a->Key < b->Key)return true;
            else if(a->Key == b->Key){
                if(a->Value <= b-> Value)return true;
                else return false;
            }else return false;
        }
    }
    
    // get [index]th node pointer 
    SplayNode *get_sub(int index , SplayNode *root){
        if(root==nullptr)return root;
        SplayNode *now = root;
        while(true){
            now->eval();
            int left_size = 0;
            if(now->left != nullptr)left_size = now->left->SubTreeSize;
            if(index < left_size)now = now->left;
            else if(index > left_size){
                now = now->right;
                index -= left_size+1;
            }else break;
        }
        now->splay();
        return now;
    }
 
    // merge 2 SplayTrees 
    SplayNode *merge(SplayNode *leftRoot , SplayNode *rightRoot){
        if(leftRoot!=nullptr)leftRoot->update();
        if(rightRoot!=nullptr)rightRoot->update();
        if(bool(leftRoot ) == false)return rightRoot;
        if(bool(rightRoot) == false )return leftRoot;
        rightRoot = get_sub(0,rightRoot);
        rightRoot->left = leftRoot;
        leftRoot->parent = rightRoot;
        rightRoot->update();
        return rightRoot;
    }
    
 
    // split SplayTree at [leftnum]
    std::pair<SplayNode*,SplayNode*> split(int leftnum, SplayNode *root){
        if(leftnum<=0)return std::make_pair(nullptr , root);
        if(leftnum >= root->SubTreeSize)return std::make_pair(root, nullptr);
        root = get_sub(leftnum , root);
        SplayNode *leftRoot = root->left;
        SplayNode *rightRoot = root;
        if(bool(rightRoot))rightRoot->left = nullptr;
        if(bool(leftRoot))leftRoot->parent = nullptr;
        leftRoot->update();
        rightRoot->update();
        return std::make_pair(leftRoot,rightRoot);
    }
 
    
    
    // remove [index]th node
    std::pair<SplayNode*,SplayNode*> Delete_sub(int index , SplayNode *root){
        if(bool(root) == false)return std::make_pair(root,root);
        root = get_sub(index,root);
        SplayNode *leftRoot = root->left;
        SplayNode *rightRoot = root->right;
        if(bool(leftRoot))leftRoot->parent = nullptr;
        if(bool(rightRoot))rightRoot->parent = nullptr;
        root->left = nullptr;
        root->right = nullptr;
        root->update();
        return std::make_pair(merge(leftRoot,rightRoot) , root );
    }
    
    
    /*
        between 2 SplayNodes [A] and [B] , we define following order.
        - if [paired_compare] is false, 
            - [A] [<] [B] represent a order of these Keys.
            - [A] [==] [B] represent these Keys are same
        - if [paired_compare] is true, 
            - [A] [<] [B] represent a lexicographical order of these (Key , Value).
            - [A] [==] [B] represent these (Key , Value) are same

        This function finds the border index [B] which satisfies following.
        1. if [lower] is true, for any [i] smaller than [B] , {[i]th node} [<] {[Node] argument}
        2. if [lower] is false, for any [i] smaller than [B] , {[i]th node} [<] {[Node] argument} or  {[i]th node} [==] {[Node] argument}
    */
    std::pair<SplayNode*,int> bound_sub(SplayNode* Node , SplayNode *root , bool lower ,  bool paired_compare ){
        if(bool(root) == false)return std::make_pair(root,0);
        SplayNode *now = root;
        int res = 0;
        Node->update();
        while(true){
            now->eval();
            bool satisfy = CompareNode(now,Node,paired_compare); // upper_bound (now <= Node)
            if(lower)satisfy = !CompareNode(Node,now,paired_compare); // lower_bound (now < Node)
            if(satisfy){
                if(bool(now->right))now = now->right;
                else {
                    res++;
                    break;
                }
            }else{
                if(bool(now->left))now = now->left;
                else break;
            }
        }
        now->splay();
        if(bool(now->left))res += now->left->SubTreeSize;
        return std::make_pair(now ,res);
    }
    
    // insert [NODE]argument into SplayTree (in which nodes are sorted)
    SplayNode *insert_sub(SplayNode *NODE , SplayNode *root , bool paired_compare){
        NODE->update();
        if(bool(root) == false)return NODE;
        // find the border index [x] ( [x]th node [<] [NODE] argument ]
        root = bound_sub(NODE,root,true,paired_compare).first;
        root->eval();
        if(!CompareNode(NODE , root , paired_compare)){
            if(root->right != nullptr)root->right->parent = NODE;
            NODE->right = root->right;
            root->right = nullptr;
            NODE->left = root;
        }else{
            if(root->left != nullptr)root->left->parent = NODE;
            NODE->left = root->left;
            root->left = nullptr;
            NODE->right = root;
        }
        root->parent = NODE;
        root->update();
        NODE->update();
        return NODE;
    }
     
    protected:

    // root node of this tree
    SplayNode *m_Root = nullptr;

    // bluff node object (used as temporary node)
    SplayNode *m_bluff_object = nullptr;
    SplayNode* BluffObject(type_key k , type_value v){
        if(m_bluff_object == nullptr)m_bluff_object = new SplayNode(type_key(0),type_value(0));
        m_bluff_object->Key = k;
        m_bluff_object->Value = v;
        return m_bluff_object;
    }

    // flag of whether node's Values are defined
    // (Values might be undefined if we use insert() function)
    bool _paired = true; 

    void release(){while(m_Root != nullptr)this->Delete(0);}

    void init(){
        m_Root = nullptr;
        _paired = true; 
    }
    
    // pointer of leftmost node
    const SplayNode* const begin(){
        if(size() == 0)return nullptr;
        m_Root = get_sub(0,m_Root);
        return m_Root;
    }

    public:

    MyMultiSet(){init();}
    ~MyMultiSet(){release();}
    // don't copy this object
    MyMultiSet(const MyMultiSet<type_key,type_value> & x) = delete ;
    MyMultiSet& operator = ( const MyMultiSet<type_key,type_value> & x) = delete ;
    MyMultiSet ( MyMultiSet<type_key,type_value>&& x){assert(0);}
    MyMultiSet& operator = ( MyMultiSet<type_key,type_value>&& x){assert(0);}
    // this function makes whole new SplayTree object from [x] argument
    void copy(MyMultiSet<type_key,type_value>& x){
        if(this->begin() == x.begin())return;
        release();
        init();
        for(int i=0;i<x.size();i++){
            SplayNode t=x.get(i);
            this->insert_pair(t.Key,t.Value);
        }
        this->_paired = x._paired;
    }
    
    // tree size
    int size(){
        if(m_Root == nullptr)return 0;
        return m_Root->SubTreeSize;
    }

    // get copy object of [i]th node 
    SplayNode get(int i){
        assert(0 <= i && i < size());
        m_Root = get_sub(i,m_Root);
        return m_Root->copy();
    }

    // get copy object node which covers interval [l,r)
    // Ex. we can get Sum of Key in [l,r)
    SplayNode GetRange(int l , int r){
        assert(0 <= l && l < r && r <= size());
        std::pair<SplayNode*,SplayNode*> tmp = split(r,m_Root);
        SplayNode* rightRoot = tmp.second;
        tmp = split(l,tmp.first);// 部分木を取り出す。
        SplayNode res = tmp.second->copy();
        m_Root = merge(merge(tmp.first,tmp.second),rightRoot);
        return res;
    }
 
    // insert key_ 
    void insert( type_key key_ ){
        _paired = false;// undefined Value was added
        m_Root = insert_sub(new SplayNode(key_,type_value(0)) ,m_Root , false);
        return;
    }
 
    // insert (key_ , value_)
    void insert_pair( type_key key_ , type_value val_){
        assert(_paired);
        m_Root = insert_sub(new SplayNode(key_,val_) ,m_Root,true);
        return;
    }
 
    // delete [index]th element
    void Delete(int index){
        assert(0 <= index && index < size());
        std::pair<SplayNode*,SplayNode*> tmp = Delete_sub(index,m_Root);
        m_Root = tmp.first;
        if(tmp.second != nullptr)delete tmp.second;
        return;
    }

    // erase 1 element which has key_ as Key
    void erase(type_key key_){
        int it = find(key_);
        if(it!=-1)Delete(it);
        return;
    }
 
    // erase 1 element which has (key_,value_) as (Key,Value)
    void erase_pair(type_key key_ , type_value val_){
        assert(_paired);
        int it = find_pair(key_ , val_);
        if(it!=-1)Delete(it);
        return;
    }
 
    // counts nodes which < (x)
    int lower_bound(type_key x){
        if(size() == 0)return 0;
        std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,type_value(0)),m_Root,true,false);
        m_Root = tmp.first;
        return tmp.second;
    }

    // counts nodes which < (x,y)
    int lower_bound_pair(type_key x , type_value y){
        assert(_paired);
        if(size() == 0)return 0;
        std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,y),m_Root,true,true);
        m_Root = tmp.first;
        return tmp.second;
    }
 
    // counts nodes which <= (x)
    int upper_bound(type_key x){
        if(size() == 0)return 0;
        std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,type_value(0)),m_Root,false,false);
        m_Root = tmp.first;
        return tmp.second;
    }
 
    // counts nodes which <= (x,y)
    int upper_bound_pair(type_key x , type_value y){
        assert(_paired);
        if(size() == 0)return 0;
        std::pair<SplayNode*,int> tmp = bound_sub(BluffObject(x,y),m_Root,false,true);
        m_Root = tmp.first;
        return tmp.second;
    }
    
    // find the index [i] which [i]th node has x as Key (if some answer exist,return smallest)
    // if no answer is found, return -1
    int find(type_key x){
        if(size() == 0)return -1;
        if(upper_bound(x) - lower_bound(x) <= 0)return -1;
        return lower_bound(x);
    }

    // find the index [i] which [i]th node has (x,y) as (Key,Value) (if some answer exist,return smallest)
    // if no answer is found, return -1
    int find_pair(type_key x , type_value y){
        assert(_paired);
        if(size() == 0)return -1;
        if(upper_bound_pair(x,y) - lower_bound_pair(x,y) <= 0)return -1;
        return lower_bound_pair(x,y);
    }


    SplayNode back(){assert(size()>0);return get(size()-1);}
    SplayNode front(){assert(size()>0);return get(0);}
    void pop_back(){assert(size()>0);Delete(size()-1);}
    void pop_front(){assert(size()>0);Delete(0);}
    SplayNode operator [](int index){return get(index);}               
};

使用例 1 ( ABC281-E )

実行時間 379ms ( TL : 2000ms )

挿入,削除,アクセス全てが $O(\log{N})$ 時間です。ある範囲の Key の総和を取得して答えることもできます。

#include<iostream>
#include "MyMultiSet.hpp"

using std::cout ,std::endl , std::cin;

// ABC281-E (https://atcoder.jp/contests/abc281/tasks/abc281_e)
int main(){
    int n , m , k;cin >> n >> m >> k;
    long long A[200002];
    MyMultiSet<long long,long long> S;
    for(int i = 0 ; i < n ; i++)cin >> A[i];

    for(int i = 0 ; i < m ; i++)S.insert(A[i]);
    
    cout << S.GetRange(0,k).Sum_key << " ";
    for(int i = 1 ; i < n-m+1 ; i++ ){
        S.erase(A[i-1]);
        S.insert(A[i+m-1]);
        cout << S.GetRange(0,k).Sum_key << " ";
    }
    cout << endl;
    return 0;
}

使用例 2 (Library Checker - Predecessor Problem)

実行時間 5355ms ( TL: 10000ms )

std::set ほど速くはありませんが、

  • $1\leq N\leq10^7$
  • $1\leq Q \leq 10^6$

の制約でも $5$ 秒程度で終了します。

#include<iostream>
#include<string>
#include "MyMultiSet.hpp"

//Library Checker - Predecessor Problem (https://judge.yosupo.jp/problem/predecessor_problem) 
int main(){
    int n,q;cin >> n >> q;
    string t;cin >> t;
    MyMultiSet<int,int> S;
    for(int i = 0 ; i < t.size() ; i++){
        if(t[i] != '0')S.insert(i);
    }
    while(q-->0){
        int qt;cin >> qt;
        int k;cin >> k;
        if(qt == 0){
            if(S.find(k) == -1)S.insert(k);
        }else if(qt == 1)S.erase(k);
        else if(qt == 2)cout << int(S.find(k) != -1) << endl;
        else if(qt == 3){
            int it = S.lower_bound(k);
            if(it >= S.size())cout << -1 << endl;
            else cout << S[it].Key << endl;
        }else{
            int it = S.upper_bound(k)-1;
            if(it < 0) cout << -1 << endl;
            else cout << S[it].Key << endl;
        }
    }
    return 0;
}

使用例 3 ( ABC367-D )

実行時間 1102ms ( TL: 2000ms )

Value に値を入れておけば、(Key,Value) の辞書順で要素を調べることができる。Value の集約も計算できる。詳細は後述。

#include<iostream>
#include<vector>

using std::cout,std::cin,std::endl;
using std::vector;

// ABC367-D(https://atcoder.jp/contests/abc367/tasks/abc367_d)
int main(){
    int n , m;cin >> n >> m;
    vector<long long> a(n) , r(n+1,0);
    for(long long & x : a)cin >> x;

    for(int i = 0 ; i < n ; i++)r[i+1] = r[i]+a[i];
    MyMultiSet<long long , long long> S;
    for(int i = 0 ; i <= n ; i++ )S.insert_pair(r[i]%m,-n + i);
    long long accum = r[n];
    long long ans = 0;
    for(int i = 1 ; i <= n ; i++){
        accum += a[i-1];
        accum%=m;
        ans += S.upper_bound(accum) - S.lower_bound_pair(accum,i-n+1);
        S.insert_pair(accum,i);
    }
    cout << ans << endl;

    return 0;
}

概要

type_key 型の Keytype_value 型の Value を持つ順序付き集合で、(Key , Value) が辞書順にソートされている。ただし、Value を無視する場合は辞書順ではなく、Key の順序でソートされる。

  • C++ の std::set とは異なり、Key の重複を許す ( Value も当然重複 OK )。
  • get(i)[i]i 番目のノードのコピーを 0-index で取得。ただし隣接頂点へのアクセス (ポインタ) が封印されたものを返す。
  • Delete(i) で小さい順で i 番目の要素を削除する。
  • GetRange(l,r) は要素の辞書順の半開区間 [l,r) をカバーする部分木の根のコピーを返す。get() 同様に、隣接頂点のポインタは封印されている。
    • GetRange(l,r).Sum_val のようにして [l,r) の持つ要素のモノイド積を取得する

(Key,Value) に関して以下の操作が可能

  • insert_pair(k,v)
    • (k,v) を持つノードを追加
  • erase_pair(k,v)
    • (k,v) を持つノードを(存在すれば)削除する。
  • upper_bound_pair(k,v)
    • (Key,Value) が辞書順で (k,v) 以下の要素数を返す
  • lower_bound_pair(k,v)
    • (Key,Value) が辞書順で (k,v) 未満の要素数を返す
  • find_pair(k,v)
    • (Key,Value)(k,v) である要素の index を返す。存在しなければ -1 を返す (0-index)。

ノードの Key だけに注目して、通常の set のように振る舞わせることもできる。

  • insert(k)
    • Key = k である要素を追加する。ただし、Value を指定しないので Value は未定義とする。
  • erase(k)
    • Key = k である要素を一つ削除する。ただし、Value について特に指定しないことに注意。
  • upper_bound(k)
    • Keyk 以下の要素数を返す
  • lower_bound(k)
    • Keyk 未満の要素数を返す
  • find(k)
    • Keyk である要素の index を返す。存在しなければ -1 を返す (0-index)。

ペアを持つノードとKeyしか持たないノードが混在するといけない

  • insert() を呼び出した時点で、Value が未定義の要素が存在することになる
    • insert() を呼び出した後は upper_bound_pair , find_pair , RangeValueMaxQuery など、Value に関する関数を呼び出すとランタイムエラーになるようにしました。(ただし、Key だけに関係する関数は変わらず使用できる。

注意点

コピーを禁止しているので、vector など STL コンテナに乗せるのは非推奨ですが、競プロでの実用で問題になることは珍しいかも。ただし、コピーを回避して実装しないとダメ。

初期化値を与えて初期化はダメ
// これは NG !!!
vector<MyMultiSet<int,int>> SV(100 , MyMultiSet<int,int>());
デフォルトコンストラクタの使用を指示しよう
// これは OK !!!
vector<MyMultiSet<int,int>> SV(100);

その他

  • CompareNode() を変更することで要素の並び順を変更できる。
  • 必要があれば eval() に遅延評価を実装して良い。
    • 順序付きなので、評価した後に順序が崩れないように制約をつける必要がある。

高速化手段

  • SplayNodeupdate の、不要な集約の記述を消す (SubTreeSize は絶対必要)。
  • デストラクタを消す (競プロ文脈限定)。
  • など
21
19
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
19

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?