LoginSignup
1
1

More than 1 year has passed since last update.

pythonで永続セグメント木実装

Last updated at Posted at 2021-08-25

はじめに

世の中には、過去の状態にアクセスしたいという需要が存在します。それを高速に実現する一つが永続データ構造です。

参考記事

永続配列について

①永続配列の使い道

永続セグメント木を考える前に永続配列について紹介します。
qiita_p1.png
というようにデータを構築したうえで
→Ver.0の4番目を知りたい!
→Ver.1の2番目を知りたい!
のような命令(クエリ)にこたえるために、愚直にデータ構築をやると配列の長さをN、version数をQとして、NQのメモリとそれを構築する計算量がかかってしまうことが分かります。ですがよく見ると(よく見なくても)変更は一か所ずつなので、高速化ができそうだということも分かります。ここで使えるのが永続配列というデータ構造です。

②永続配列の構造

qiita_p2.png
図は二分木で構築した永続配列(Ver.0)です。二分木の葉の部分にのみ値を持ちほかのノードは値を持ちません(根から該当箇所にアクセスするために値を持たないノードが存在しているといった具合です)。

根からのアクセスの仕方を見ておきましょう。
indexの1番にアクセスしたいときは根からLLRとたどればよく、
indexの6番にアクセスしたいときは根からRRLとたどればいいことが分かります。
Lを0、Rを1とすると、indexの1番は、001→1,indexの6番は、110→6となり、インデックス番号とたどり方が
1対1で対応していることが分かります。アクセスはこの性質をつかえばよさそうです。計算量はO(logN)です。

③更新とクエリの答え方

Ver.0の3番目を3に変えたものをVer.1とする操作をやってみます。これは赤いところのようにノードを追加し、pathをつなげば実現できます(計算量はO(logN)です。
qiita_p3.png
クエリ(Ver.○○のindex○○の値を取得)に対しては、ver.0の時は青色の根から、ver.1の時は赤色の根から始め、末端のノードにアクセスすればしていけばよいです。

例えばVer.1の2番目を知りたい場合には緑の矢印をたどればよいです。

④実装例

pythonでの実装例を紹介します。

Persistent_Array.py
class Node:
    def __init__(self,default):
        self.rch = None
        self.lch = None
        self.val = default

class PersistentArray:
    def __init__(self, ls, default_ver=0):
        '''
        ls (list) : 永続配列にしたい配列
        default_ver (any): 最初のversion(デフォルト:0)
        '''
        N = len(ls)
        self.N = N
        self.K = (N - 1).bit_length()
        self.N2 = 1 << self.K
        self.dat = [Node(None) for i in range(2**(self.K + 1))]
        for i in range(self.N):  # 葉の構築
            self.dat[self.N2 + i].val = ls[i]
        self.build()
        self.verdict = dict()
        self.verdict[default_ver] = 1 # 各versionの根のindexを格納

    def build(self):
        for node in range(self.N2 - 1, 0, -1):
            self.dat[node].rch = self.dat[node<<1 | 1]
            self.dat[node].lch = self.dat[node<<1]
            self.dat.pop()
            self.dat.pop()

    def get_t_x(self, t, x):  # ver.tにおけるリストのx番目の値
        '''
        version:tにおけるindex:xの値を出力(O(logN))
        '''
        x += self.N2
        v = self.dat[self.verdict[t]] # ver.tの根
        path = bin(x)[3:]
        for i in path:
            if i == '0':
                v = v.lch
            else:
                v = v.rch
        return v.val

    def update_told_tnew_x_val(self, told, tnew, x, val):
        '''
        version:toldのindex:xをvalに変更したものをversion:tnewとする(O(logN))
        told: 変更前のversion
        tnew: 変更後のversion
        x: 変更するindex
        val: 変更後の値
        '''
        if not (told in self.verdict):
            raise('No such version exists')
        x += self.N2
        path = bin(x)[3:]
        self.verdict[tnew] = len(self.dat)
        new_nodes = [Node(None) for _ in range(len(path)+1)]
        v_old = self.dat[self.verdict[told]]
        v_new = new_nodes[0]
        now = 1
        for i in path: # ノードをつなげる
            if i == '0':
                v_new.rch = v_old.rch
                v_new.lch = new_nodes[now]
                v_new = new_nodes[now]
                v_old = v_old.lch 
            else:
                v_new.lch = v_old.lch
                v_new.rch = new_nodes[now]
                v_new = new_nodes[now]
                v_old = v_old.rch
            now += 1
        v_new.val = val
        self.dat.append(new_nodes[0])

    def get_t_all(self,t):
        '''
        version:tにおける配列を出力(O(NlogN))
        return : list
        '''
        if not (t in self.verdict):
            raise('No such version exists')
        ret = []
        for i in range(self.N2,self.N2+self.N):
            path = bin(i)[3:]
            v = self.dat[self.verdict[t]]
            for p in path:
                if p == '0':
                    v = v.lch
                else:
                    v = v.rch
            ret.append(v.val)
        return ret

    def __getitem__(self, xt): return self.get_t_x(xt[0], xt[1])

Nodeクラスに左子(lch)、右子(rch)、値(val)を持たせ、二分木を構築しています。

⑤使用例

use.py
SG = PersistentArray([2,4,6,8,9,1,3,5])
SG.update_told_tnew_x_val(0,1,3,3)
SG.update_told_tnew_x_val(1,2,4,5)
SG.update_told_tnew_x_val(1,3,5,2)
print(SG.get_t_x(0,3))
# 8
print(SG.get_t_x(1,3))
# 3
print(SG.get_t_all(0))
# [2, 4, 6, 8, 9, 1, 3, 5]
print(SG.get_t_all(1))
# [2, 4, 6, 3, 9, 1, 3, 5]
print(SG.get_t_all(2))
# [2, 4, 6, 3, 5, 1, 3, 5]
print(SG.get_t_all(3))
# [2, 4, 6, 3, 9, 2, 3, 5]

⑥完全永続性

図を修正しました(21/8/26)
qiita_p4.png

赤色矢印のように過去のversionから枝分かれするようなversionを作製できるのが完全永続、できないのが部分永続です。今回の永続配列は、過去の状態からの枝分かれも作成できるので完全永続となっています。

永続セグメント木

実は二分木の永続配列が分かれば永続セグメント木はあまり難しくありません。具体的には永続配列の二分木の親ノードに子ノードから求められる値を記録していき、区間クエリ時に読み取るといった操作が加わります。

①永続セグメント木の構造

qiita_p5.png
今回紹介した永続配列は完全二分木なので、そのままセグメント木への拡張ができます。具体にはアクセスするときにpathとして使っていたノードに親の値を持たせます。セグメント木に乗せる関数は加算(sum)としました。具体には図のようになります。

②更新のやり方

qiita_p6.png
先ほどの永続配列とほぼ同様です。異なるところは親の値の更新が入るところです。計算用は変わらずO(logN)でできます。

③区間クエリの取得

qiita_p7.png
根からdfsをしていき、読み取るべきノードまで行ったらストップします。また区間外に出る場合には枝かりをすることで計算量を抑えることができます(もっといいアルゴリズムがあるかもです)。
Ver.1の[1,4)を取得する場合、考えるべき区間は緑の部分です。枝かり+該当箇所でストップすることで緑矢印部分のpath探索で済みます(4+9=13)。(次項の実装のところでは厳密には、末端の枝かりが甘いですが、大した計算量ではないので、場合分けは省いています。)

④実装例

pythonでの実装例を紹介します。

Persistent_SegTree.py
import math
class Node:
    def __init__(self,default):
        self.rch = None
        self.lch = None
        self.val = default

class PersistentSegTree:
    DEFAULT = {
        'min': 1 << 60,
        'max': -(1 << 60),
        'sum': 0,
        'prd': 1,
        'gcd': 0,
        'lmc': 1,
        '^': 0,
        '&': (1 << 60) - 1,
        '|': 0,
    }

    FUNC = {
        'min': min,
        'max': max,
        'sum': (lambda x, y: x + y),
        'prd': (lambda x, y: x * y),
        'gcd': math.gcd,
        'lmc': (lambda x, y: (x * y) // math.gcd(x, y)),
        '^': (lambda x, y: x ^ y),
        '&': (lambda x, y: x & y),
        '|': (lambda x, y: x | y),
    }

    def __init__(self, ls, mode='min', func=None, default=None, default_ver=0):
        '''
        ls (list) : 永続配列にしたい配列
        mode : モード(デフォルト:min)
        func: 関数を指定(デフォルト:None) ※modeを指定する場合は不要
        default: 単位元を指定(デフォルト:None) ※modeを指定する場合は不要
        default_ver (any): 最初のversion(デフォルト:0)
        '''
        N = len(ls)
        if default == None:
            self.default = self.DEFAULT[mode]
        else:
            self.default = default
        if func == None:
            self.func = self.FUNC[mode]
        else:
            self.func = func
        self.N = N
        self.K = (N - 1).bit_length()
        self.N2 = 1 << self.K
        self.dat = [Node(self.default) for _ in range(2**(self.K + 1))]
        for i in range(self.N):  # 葉の構築
            self.dat[self.N2 + i].val = ls[i]
        self.build()
        self.verdict = dict()
        self.verdict[default_ver] = 1 # 各versionの根のindexを格納

    def build(self):
        for node in range(self.N2 - 1, 0, -1):
            self.dat[node].rch = self.dat[node<<1 | 1]
            self.dat[node].lch = self.dat[node<<1]
            self.dat[node].val = self.func(self.dat[node<<1 | 1].val,self.dat[node<<1].val)
            self.dat.pop()
            self.dat.pop()

    def get_t_x(self, t, x):  # ver.tにおけるリストのx番目の値
        '''
        version:tにおけるindex:xの値を出力(O(logN))
        return : int
        '''
        x += self.N2
        v = self.dat[self.verdict[t]] # ver.tの根
        path = bin(x)[3:]
        for i in path:
            if i == '0':
                v = v.lch
            else:
                v = v.rch
        return v.val

    def update_told_tnew_x_val(self, told, tnew, x, val):
        '''
        version:toldのindex:xをvalに変更したものをtnewとする(O(logN))
        told: 変更前のversion
        tnew: 変更後のversion
        x: 変更するindex
        val: 変更後の値
        '''
        if not (told in self.verdict):
            raise('No such version exists')
        x += self.N2
        path = bin(x)[3:]
        self.verdict[tnew] = len(self.dat)
        new_nodes = [Node(self.default) for _ in range(len(path)+1)]
        v_old = self.dat[self.verdict[told]]
        v_new = new_nodes[0]
        now = 1
        for i in path: # ノードをつなげる
            if i == '0':
                v_new.rch = v_old.rch
                v_new.lch = new_nodes[now]
                v_new = new_nodes[now]
                v_old = v_old.lch 
            else:
                v_new.lch = v_old.lch
                v_new.rch = new_nodes[now]
                v_new = new_nodes[now]
                v_old = v_old.rch
            now += 1
        v_new.val = val
        for node in range(len(path)-1,-1,-1): # 付け加えたノードの値を子から親に更新
            new_nodes[node].val = self.func(new_nodes[node].lch.val,new_nodes[node].rch.val)
        self.dat.append(new_nodes[0])

    def get_t_all(self,t): 
        '''
        version:tにおける配列を出力(O(NlogN))
        return : list
        '''
        if not (t in self.verdict):
            raise('No such version exists')
        ret = []
        for i in range(self.N2,self.N2+self.N):
            path = bin(i)[3:]
            v = self.dat[self.verdict[t]]
            for p in path:
                if p == '0':
                    v = v.lch
                else:
                    v = v.rch
            ret.append(v.val)
        return ret

    def query_t(self, t, L, R):
        '''
        version:tにおける区間クエリ[L,R)(O(logN))
        return : int
        '''
        L += self.N2; R += self.N2
        L0 = L; R0 = R
        ls = set() # 見るべきノード番号
        while L < R:
            if L & 1:
                ls.add(L)
                L += 1
            if R & 1:
                R -= 1
                ls.add(R)
            L >>= 1
            R >>= 1
        valdef = self.default
        d = [(1,0,self.dat[self.verdict[t]])]
        while d: # 二分木上をdfs
            num, depth, node = d.pop()
            if num in ls:
                valdef = self.func(node.val, valdef)
                continue
            if depth == self.K:
                continue
            if num >= (L0 >> (self.K-depth)):
                d.append((num << 1,depth+1, node.lch))
            if R0 > ((num << 1 | 1) << (self.K-depth-1)):
                d.append((num << 1 | 1, depth+1 , node.rch))
        return valdef

    def __getitem__(self, tx): return self.get_t_x(tx[0], tx[1])

⑤使用例

use.py
SG = PersistentSegTree([2,4,6,8,9,1,3,5],mode='sum')
SG.update_told_tnew_x_val(0,1,3,3)
SG.update_told_tnew_x_val(1,2,4,5)
SG.update_told_tnew_x_val(1,3,5,2)
print(SG.get_t_all(1))
# [2, 4, 6, 3, 9, 1, 3, 5]
print(SG.get_t_all(2))
# [2, 4, 6, 3, 5, 1, 3, 5]
print(SG.get_t_all(3))
# [2, 4, 6, 3, 9, 2, 3, 5]
print(SG.query_t(0,0,6))
# 30
print(SG.query_t(1,0,6))
# 25
print(SG.query_t(2,0,6))
# 21
print(SG.query_t(3,0,6))
# 26

⑥競技プログラミングでの使用例

Atcoder ABC165F LIS on Tree 提出結果(TLE) https://atcoder.jp/contests/abc165/submissions/25330532
何回か提出すると2100msくらいのTLEにもなったので、TLが2.5secなら通っていたかもです。なかなか制約が厳しいようです。

yosupo Range Kth Smallest 提出結果AC 3414ms https://judge.yosupo.jp/submission/62383
少し工夫すると区間のk番目の値をlog(N)で求めるアルゴリズムが書けます。

最後に

今回作っては見ましたが、競技プログラミングにおいてpythonで通せる永続セグメント木問題はあまりないのかもしれません。。。そろそろC++を学ばなければいけませんね。。。
LGTMしていただけると大変励みになりますので参考になった方いましたらよろしくお願いいたします。

python を使ったデータ構造の記事は他にも書いてありますのでよかったら見ていってください。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1