11
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

pythonで平衡二分木を使った辞書を作ってみた(C++ std::map相当)

Last updated at Posted at 2021-08-15

はじめに

pythonのset(集合)やdict(辞書)は値の検索削除がO(1)と高速ですが、要素の大きさに順序を持つことができず、順序を気にしながらの実装を愚直にすると計算時間が大きくかかってしまうなどの課題があります。一方でc++では、順序付き集合のsetや順序付きの連想配列のmap(辞書相当)が標準で整備されており、競技プログラミングなど時間制約がタイトな場合では、標準ライブラリーの差でもpython勢は不利になることが多いです。そこで今回はc++のmapに相当する平衡二分木を用いた辞書の実装例を紹介したいと思います。

参考記事

今回は平衡二分木部分に関しては出来合いのものを用いました。
Pythonで非再帰AVL木 - 競プロ記録 他
(非再帰での実装となっておりpythonでも高速に動くためとても重宝しております)
こちらで紹介されている記事の実装例にいくつか機能を付与する形で実装しました。

コードの紹介

コードがとても長くなってしまっているので、関数及び使用方法について次項で重点的に説明しようと思います。

コードの表示はこちらをクリック
AVLTree_dict.py
#非再帰の平衡二分木
import copy
class Node:
    """ノード

    Attributes:
        key (any): ノードのキー。比較可能なものであれば良い。(1, 4)などタプルも可。
        val (any): ノードの値。
        lch (Node): 左の子ノード。
        rch (Node): 右の子ノード。
        bias (int): 平衡度。(左部分木の高さ)-(右部分木の高さ)。
        size (int): 自分を根とする部分木の大きさ

    """

    def __init__(self, key, val):
        self.key = key
        self.val = val
        self.lch = None
        self.rch = None
        self.bias = 0
        self.size = 1

class AVLTree:
    """非再帰AVL木

    Attributes:
        root (Node): 根ノード。初期値はNone。
        valdefault (any): ノード値のデフォルト値。デフォルトではNone。(数値、リストなど可)

    """

    def __init__(self,valdefault=None):
        self.root = None
        self.valdefault = valdefault
        
    def rotate_left(self, v):
        u = v.rch
        u.size = v.size
        v.size -= u.rch.size + 1 if u.rch is not None else 1
        v.rch = u.lch
        u.lch = v
        if u.bias == -1:
            u.bias = v.bias = 0
        else:
            u.bias = 1
            v.bias = -1
        return u

    def rotate_right(self, v):
        u = v.lch
        u.size = v.size
        v.size -= u.lch.size + 1 if u.lch is not None else 1
        v.lch = u.rch
        u.rch = v
        if u.bias == 1:
            u.bias = v.bias = 0
        else:
            u.bias = -1
            v.bias = 1
        return u

    def rotateLR(self, v):
        u = v.lch
        t = u.rch
        t.size = v.size
        v.size -= u.size - (t.rch.size if t.rch is not None else 0)
        u.size -= t.rch.size + 1 if t.rch is not None else 1
        u.rch = t.lch
        t.lch = u
        v.lch = t.rch
        t.rch = v
        self.update_bias_double(t)
        return t

    def rotateRL(self, v):
        u = v.rch
        t = u.lch
        t.size = v.size
        v.size -= u.size - (t.lch.size if t.lch is not None else 0)
        u.size -= t.lch.size + 1 if t.lch is not None else 1
        u.lch = t.rch
        t.rch = u
        v.rch = t.lch
        t.lch = v
        self.update_bias_double(t)
        return t

    def update_bias_double(self, v):
        if v.bias == 1:
            v.rch.bias = -1
            v.lch.bias = 0
        elif v.bias == -1:
            v.rch.bias = 0
            v.lch.bias = 1
        else:
            v.rch.bias = 0
            v.lch.bias = 0
        v.bias = 0

    def insert(self, key, val=None):
        """挿入

        指定したkeyを挿入する。valはkeyのノード値。

        Args:
            key (any): キー。
            val (any): 値。(指定しない場合はvaldefaultが入る)

        Note:
            同じキーがあった場合は上書きする。

        """
        if val == None:
            val = copy.deepcopy(self.valdefault)
        if self.root is None:
            self.root = Node(key, val)
            return

        v = self.root
        history = []
        while v is not None:
            if key < v.key:
                history.append((v, 1))
                v = v.lch
            elif v.key < key:
                history.append((v, -1))
                v = v.rch
            elif v.key == key:
                v.val = val
                return

        p, pdir = history[-1]
        if pdir == 1:
            p.lch = Node(key, val)
        else:
            p.rch = Node(key, val)

        while history:
            v, direction = history.pop()
            v.bias += direction
            v.size += 1

            new_v = None
            b = v.bias
            if b == 0:
                break

            if b == 2:
                u = v.lch
                if u.bias == -1:
                    new_v = self.rotateLR(v)
                else:
                    new_v = self.rotate_right(v)
                break
            if b == -2:
                u = v.rch
                if u.bias == 1:
                    new_v = self.rotateRL(v)
                else:
                    new_v = self.rotate_left(v)
                break

        if new_v is not None:
            if len(history) == 0:
                self.root = new_v
                return
            p, pdir = history.pop()
            p.size += 1
            if pdir == 1:
                p.lch = new_v
            else:
                p.rch = new_v

        while history:
            p, pdir = history.pop()
            p.size += 1

    def delete(self, key):
        """削除

        指定したkeyの要素を削除する。

        Args:
            key (any): キー。

        Return:
            bool: 指定したキーが存在するならTrue、しないならFalse。

        """
        v = self.root
        history = []
        while v is not None:
            if key < v.key:
                history.append((v, 1))
                v = v.lch
            elif v.key < key:
                history.append((v, -1))
                v = v.rch
            else:
                break
        else:
            return False

        if v.lch is not None:
            history.append((v, 1))
            lmax = v.lch
            while lmax.rch is not None:
                history.append((lmax, -1))
                lmax = lmax.rch

            v.key = lmax.key
            v.val = lmax.val

            v = lmax

        c = v.rch if v.lch is None else v.lch

        if history:
            p, pdir = history[-1]
            if pdir == 1:
                p.lch = c
            else:
                p.rch = c
        else:
            self.root = c
            return True

        while history:
            new_p = None

            p, pdir = history.pop()
            p.bias -= pdir
            p.size -= 1

            b = p.bias
            if b == 2:
                if p.lch.bias == -1:
                    new_p = self.rotateLR(p)
                else:
                    new_p = self.rotate_right(p)

            elif b == -2:
                if p.rch.bias == 1:
                    new_p = self.rotateRL(p)
                else:
                    new_p = self.rotate_left(p)

            elif b != 0:
                break

            if new_p is not None:
                if len(history) == 0:
                    self.root = new_p
                    return True
                gp, gpdir = history[-1]
                if gpdir == 1:
                    gp.lch = new_p
                else:
                    gp.rch = new_p
                if new_p.bias != 0:
                    break

        while history:
            p, pdir = history.pop()
            p.size -= 1

        return True

    def member(self, key):
        """キーの存在チェック

        指定したkeyがあるかどうか判定する。

        Args:
            key (any): キー。

        Return:
            bool: 指定したキーが存在するならTrue、しないならFalse。

        """
        v = self.root
        while v is not None:
            if key < v.key:
                v = v.lch
            elif v.key < key:
                v = v.rch
            else:
                return True
        return False

    def getval(self, key):
        """値の取り出し

        指定したkeyの値を返す。

        Args:
            key (any): キー。

        Return:
            any: 指定したキーが存在するならそのオブジェクト。存在しなければvaldefault

        """
        v = self.root
        while v is not None:
            if key < v.key:
                v = v.lch
            elif v.key < key:
                v = v.rch
            else:
                return v.val
        self.insert(key)
        return self[key] #

    def lower_bound(self, key):
        """下限つき探索

        指定したkey以上で最小のキーを見つける。[key,inf)で最小

        Args:
            key (any): キーの下限。

        Return:
            any: 条件を満たすようなキー。そのようなキーが一つも存在しないならNone。

        """
        ret = None
        v = self.root
        while v is not None:
            if v.key >= key:
                if ret is None or ret > v.key:
                    ret = v.key
                v = v.lch
            else:
                v = v.rch
        return ret

    def upper_bound(self, key):
        """上限つき探索

        指定したkey未満で最大のキーを見つける。[-inf,key)で最大

        Args:
            key (any): キーの上限。

        Return:
            any: 条件を満たすようなキー。そのようなキーが一つも存在しないならNone。

        """
        ret = None
        v = self.root
        while v is not None:
            if v.key < key:
                if ret is None or ret < v.key:
                    ret = v.key
                v = v.rch
            else:
                v = v.lch
        return ret

    def find_kth_element(self, k):
        """小さい方からk番目の要素を見つける

        Args:
            k (int): 何番目の要素か(0オリジン)。

        Return:
            any: 小さい方からk番目のキーの値。
        """
        v = self.root
        s = 0
        while v is not None:
            t = s+v.lch.size if v.lch is not None else s
            if t == k:
                return v.key
            elif t < k:
                s = t+1
                v = v.rch
            else:
                v = v.lch
        return None
    
    def getmin(self):
        '''
        Return:
            any: 存在するキーの最小値
        '''
        if len(self) == 0:
            raise('empty')
        ret = None
        v = self.root
        while True:
            ret = v
            v = v.lch
            if v == None:
                break
        return ret.key
    
    def getmax(self):
        '''
        Return:
            any: 存在するキーの最大値
        '''
        if len(self) == 0:
            raise('empty')
        ret = None
        v = self.root
        while True:
            ret = v
            v = v.rch
            if v == None:
                break
        return ret.key
    
    def popmin(self):
        '''
        存在するキーの最小値をpopする
        Return:
            any: popした値
        '''
        if len(self) == 0:
            raise('empty')
        ret = None
        v = self.root
        while True:
            ret = v
            v = v.lch
            if v == None:
                break      
        del self[ret.key]
        return ret.key
    
    def popmax(self):
        '''
        存在するキーの最大値をpopする
        Return:
            any: popした値
        '''
        if len(self) == 0:
            raise('empty')
        ret = None
        v = self.root
        while True:
            ret = v
            v = v.rch
            if v == None:
                break
        del self[ret.key]
        return ret.key

    def popkth(self,k):
        '''
        存在するキーの小さい方からk番目をpopする
        Return:
            any: popした値
        '''
        key = self.find_kth_element(k)
        del self[key]
        return key
    
    def get_key_val(self):
        '''
        Return:
            dict: 存在するキーとノード値をdictで出力
        '''
        retdict = dict()
        for i in range(len(self)):
            key = self.find_kth_element(i)
            val = self[key]
            retdict[key] = val
        return retdict
    
    def values(self):
        for i in range(len(self)):
            yield self[self.find_kth_element(i)]

    def keys(self):
        for i in range(len(self)):
            yield self.find_kth_element(i)
    
    def items(self):
        for i in range(len(self)):
            key = self.find_kth_element(i)
            yield key,self[key]

    def __iter__(self): return self.keys()
    def __contains__(self, key): return self.member(key)
    def __getitem__(self, key): return self.getval(key)
    def __setitem__(self, key, val): return self.insert(key, val)
    def __delitem__(self, key): return self.delete(key)
    def __bool__(self): return self.root is not None
    def __len__(self): return self.root.size if self.root is not None else 0
    def __str__(self): return str(type(self))+'('+str(self.get_key_val())+')'

関数の説明

関数 説明
rotate_left - update_bias_double 平衡二分木を実現するうえで最も大事な部分です(平衡を保つための回転などを行っています)。外部で呼び出す必要はないので、使用上は気にしなくていい部分です。
insert(key,val=None) keyを平衡二分木に挿入します。valはkeyのノード値(value)です(辞書のkey,valueと同等)
delete(key) 指定したkeyの要素を削除します。
menber(key) 指定したkeyの有無を判定します。
getval(key) 指定したkeyの値のノード値を返します。(dict[key]と同等)
lower_bound(key) [key,inf)で最小のものを返します。
upper_bound(key) [-inf,key)で最大のものを返します。
find_kth_element(k) 小さい方からk番目のkeyを返します(最小を0番目とする)

以下元記事にはなかった追加機能です。

関数 説明
getmin() 最小のkeyを返します。
getmax() 最大のkeyを返します。
popmin() 最小のkeyを返し、削除します。(heapqと同じ操作が実現可能)
popmax() 最大のkeyを返し、削除します。
popkth(k) 小さい方からk番目のkeyを返し、削除します。
get_key_val() 存在するすべてのkeyとノード値(value)をdictで返します。
values() すべてのノード値(value)を返します。(dictのvalues()と同等)
keys() すべてのkeyを返します。(dictのkeys()と同等)
items() すべてのkeyとノード値(value)を返します。(dictのitems()と同等)

計算量はget_key_val(),values(),keys(),items()がO(Nlog(N))、他はO(log(N))と高速です。
その他に、使用感をdictに合わせるために特殊メゾットを定義しています。

使用例

注意点としてはkeyは大小比較可能なのものである必要があるというところです(数値、文字列、tupleなどは可能)

・普通の平衡二分木としての動作

a.py
AVL = AVLTree()
AVL.insert(10)
AVL.insert(20)
AVL.insert(30)
AVL.insert(40)
AVL.insert(50)
print(AVL.lower_bound(15))
# 20
print(AVL.find_kth_element(2))
# 30
print(40 in AVL)
# True
del AVL[40]   # AVL.delete(40)と等価
print(40 in AVL)
# False
print(list(AVL))
# [10, 20, 30, 50]
print(AVL.popmin())
# 10
print(AVL.popkth(1)) # 20,30,50のうち1番目(0オリジン)の30
# 30
print(list(AVL))
# [20, 50]
print(len(AVL))
# 2

・辞書っぽい使い方

b.py
AVL1 = AVLTree()
AVL1['a'] = 'A'
AVL1['b'] = 'B'
AVL1['f'] = 'C'
AVL1['aa'] = 'AA'
print(list(AVL1))
# ['a', 'aa', 'b', 'f']
print(AVL1.get_key_val())
# {'a': 'A', 'aa': 'AA', 'b': 'B', 'f': 'C'}
print(AVL1.getmax())
# f
print(AVL1.upper_bound('e'))
# b

・collections.defaultdict相当の動作

c.py
# はじめにvaldefaultを指定することでdefaultdictに相当する処理が可能。
AVL2 = AVLTree(valdefault=[])
AVL2[20].append(2)
AVL2[20].append(3)
AVL2[20].append(6)
AVL2[30].append(5)
AVL2[40].append(1)
AVL2[40].append(2)
print(AVL2.get_key_val())
# {20: [2, 3, 6], 30: [5], 40: [1, 2]}
print(AVL2[20].pop())
# 6
print(40 in AVL2)
# True
print(50 in AVL2)
# False
print(AVL2.popmax())
# 40
AVL2[50].append(5)
AVL2[50].append(6)
print(AVL2.get_key_val())
# {20: [2, 3], 30: [5], 50: [5, 6]}

・collections.Counter相当の動作

d.py
AVL3 = AVLTree(valdefault=0)
AVL3[30] += 3
AVL3[40] += 2
AVL3[50] += 1
print(AVL3.get_key_val())
# {30: 3, 40: 2, 50: 1}
AVL3[50] += 5
print(AVL3.get_key_val())
# {30: 3, 40: 2, 50: 6}
print(list(AVL3.values()))
# [3, 2, 6]
for key,val in AVL3.items():
    print(key,val)
    # 30 3
    # 40 2
    # 50 6
while AVL3:
    key = AVL3.getmin()
    print(key,AVL3[key])
    AVL3.popmin()
    # 30 3
    # 40 2
    # 50 6

競技プログラミングでの使用例

・atcoder ABC140F 提出結果
・atcoder AGC005B 提出結果
・defaultdictとしての例: atcoder diverta 2019 Programming Contest 2 提出結果
・atcoder ABC217D 提出結果
・atcoder ABC217E 提出結果

最後に

LGTMしていただけると大変励みになりますので参考になった方いましたらよろしくお願いいたします。

11
9
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?