LoginSignup
1
1

PyPyでの区間取得系データ構造の簡単な速度比較(競プロ文脈)

Last updated at Posted at 2023-06-02

こんにちは。株式会社オプティマインドの伊豆原と申します。当社の最適化チームに所属しており、また社内部活として競プロ部に所属してます。

競プロでは配列データ$[a_0,\cdots,a_{N-1}]$と2項演算$\oplus$が与えられた時に、連続部分列$[a_l,\cdots,a_{r-1}]$に対する$a_l \oplus a_{l+1} \oplus \cdots \oplus a_{r-1}$を高速に求めたいことがよくあります。こういった時は計算量$\mathcal{O}(\log N)$で求められるSegment TreeやBinary Indexed Tree、データの更新がなければSparse Table(こちらは$\mathcal{O}(1)$で求めれます)などがよく使われ、また一般に(ニーズが合えば)Segment TreeよりもBinary Indexed TreeやSparse Tableのほうが"高速"ともよく言われます。

今回はこの"高速"と呼ばれる件について、実際に速度を比較してどのくらい速いのかを確認する簡単な検証記事になります。使用言語は個人的に競プロで使用しているPython(実行はPyPy)を使用します。

Segment Tree

Segment Treeはモノイド則を満たす二項演算についてその区間取得get(l,r)と一点更新add(i,val)を配列長$N$に対して$\mathcal{O}(\log N)$で行えるデータ構造になります。

今回は下のように実装したものを使います。

class SegmentTree:
    def __init__(self,data,op,default):
        N = len(data)
        self.op = op
        self.default = default
        self.l = 2**((N-1).bit_length())
        self.data = [default]*self.l + data + [default]*(self.l-N)
        for i in range(self.l-1,0,-1):
            self.data[i] = op(self.data[2*i], self.data[2*i+1])
    def add(self,i,val):
        i += self.l
        self.data[i] = self.op(self.data[i], val)
        i = i//2
        while i > 0:
            self.data[i] = self.op(self.data[2*i], self.data[2*i+1])
            i = i//2
    def get(self,i,j):
        i += self.l
        j += self.l
        s = self.default 
        while j-i > 0:
            if i & 1:
                s = self.op(s,self.data[i])
                i += 1
            if j & 1:
                s = self.op(s,self.data[j-1])
                j -= 1
            i, j = i//2, j//2
        return s

Binary Indexed Tree

Binary Indexed Treeはモノイド則を満たす二項演算についてその$[0,i)$の形の区間取得get(i)と一点更新add(i,val)を配列長$N$に対して$\mathcal{O}(\log N)$で行えるデータ構造になります。仮に一般の区間$[l,r)$での値を取得したい場合、これは二項演算が可逆($\Leftrightarrow \forall a, \exists b(=-a),$ s.t. $a \oplus b=0$)ならば可能でして、get(r) $\oplus$ (- get(l))として計算することができます。

今回は下のように実装したものを使います。

class BinaryIndexedTree:
    def __init__(self,data,op,default):
        self.op, self.default, N = op, default, len(data)
        self.l = 2**((N-1).bit_length())
        self.data = [default]*(self.l+1)
        for i in range(1,N+1):
            v = data[i-1]
            while i <= self.l:
                self.data[i], i = self.op(self.data[i], v), i+(i&-i)
    def add(self,i,val):
        i += 1
        while i <= self.l:
            self.data[i],i = self.op(self.data[i], val),i+(i&-i)
    def get(self,i): # [0,i)
        res = self.default
        while i > 0:
            res,i = self.op(res, self.data[i]),i-(i&-i)
        return res

Sparse Table

Sparse Tableは結合法則と冪等性を満たす二項演算についてその区間取得get(l,r)を配列長$N$に対して$\mathcal{O}(1)$で行えるデータ構造になります(前処理に$\mathcal{O}(N\log N)$掛かります)。冪等性が必要なので足し算や掛け算は適用できず、maxやminが典型的な例になります。

今回は下のように実装したものを使います。

class SparseTable:
    def __init__(self,data,op):
        N = len(data)
        self.N = N
        self.op = op
        self.nrows = N.bit_length()
        self.pow = [2**p for p in range(self.nrows)]
        self.table = [0]*(self.N*self.nrows)
        for i in range(N):
            self.table[i] = data[i]
        step = 1
        for row in range(1, self.nrows):
            l,r = N*(row-1), N*row-1
            for i in range(N):
                self.table[l+N+i] = op(self.table[l+i], self.table[min(l+i+step,r)])
            step *= 2
    def get(self,l,r):
        p = (r-l).bit_length()-1
        row = self.N * p
        w = self.pow[p] # 2**p
        return self.op(self.table[row+l],self.table[row+r-w])

Wavelet Matrix

Wavelet Matrtixはこれまでに紹介したデータ構造と違い演算の種類は限られますが、整数列に対する幅広い操作を$\log N$の計算量で可能にします。今回はmax操作の代わりとしてquantile(l,r,i)([l,r)の中でi番目に大きい値を取得する)のみ実装します。またWavelet Matrixの実装では簡潔ビットベクトルというものがよく使われるようですが、(面倒なので)今回はそれっぽいビットベクトルを素直に実装します。

実装には下記のサイトを参考にさせて頂きました。

ウェーブレット行列(wavelet matrix) - Eating Your Own Cat Food

なお計算量は大きくなりますが、データの更新を可能にした動的Wavelet Matrixというものもあるようです。

import array as ar
from itertools import accumulate

class BitVector:
    def __init__(self, data):
        self.n = len(data)
        self.data = ar.array('b',data)
        self.cumsum = ar.array('i',[0] + list(accumulate(data)))
    def access(self, i):
        return self.data[i]
    def rank(self, i, x):
        return self.cumsum[i] if x else i - self.cumsum[i]
    def select(self, i, x):
        l,r = 0, self.n
        while r-l > 1:
            m = (r+l)//2
            n = self.cumsum[m] if x else m - self.cumsum[m]
            if n < i+1:
                l = m
            else:
                r = m
        return l
class WaveletMatrix:
    def __init__(self, data):
        self.T = data[:]
        self.N = len(data)
        max_value = max(data)
        self.l = max_value.bit_length()-1
        th = 1 << self.l
        self.B = []
        self.cnt0 = []
        while th:
            self.B.append(BitVector(ar.array('b',((t&th)==th for t in self.T))))
            L,R = [],[]
            for t,b in zip(self.T, self.B[-1].data):
                if b:
                    R.append(t)
                else:
                    L.append(t)
            self.cnt0.append(len(L))
            self.T = L + R
            th >>= 1
    # find i-th value in [l,r)
    def quantile(self, l, r, i):
        th = 1 << self.l
        for c0, row in zip(self.cnt0,self.B):
            n0 = row.rank(r,0) - row.rank(l,0)
            n1 = row.rank(r,1) - row.rank(l,1)
            if n0 <= i:
                i -= n0
                l = c0 + row.rank(l,1)
                r = c0 + row.rank(r,1)
            else:
                l = row.rank(l,0)
                r = row.rank(r,0)
            th >>= 1
        return self.T[r-1]

Segment Tree vs Binary Indexed Tree

では速度比較に入ります。計測環境はMacBookProでプロセッサは2GHz クアッドコアIntel Core i5。メモリは16GBです。動作にはPyPy 7.3.9を使用します。

まずはSegment TreeとBinary Indexed Treeについて足し算でのsetgetの比較をします。Binary Indexed Treeに関しては上述の理由からgetは2回行う形になります。なおPyPyはJITなので、計測値は以下のコードを10回程度廻した上での10回目の結果を使います。

from time import perf_counter as time
from operator import add
import random

N = 2*10**5 # or 10**6
Q = 5*10**5
A = [random.randint(0,10**9) for _ in range(N)]

add_queries = [(random.randint(0,N-1),random.randint(0,10**9)) for _ in range(Q)]
get_queries = [sorted(random.sample(range(N),2)) for _ in range(Q)]

seg = SegmentTree(A, add, 0)
bit = BinaryIndexedTree(A, add, 0)

start = time()
for i,x in add_queries:
    seg.add(i,x)
print(f"seg_add: {time()-start}") 

start = time()
for l,r in get_queries:
    seg.get(l,r)
print(f"seg_get: {time()-start}") 

start = time()
for i,x in add_queries:
    bit.add(i,x)
print(f"bit_add: {time()-start}") 

start = time()
for l,r in get_queries:
    bit.get(r)-bit.get(l)
print(f"bit_get: {time()-start}") 

以下は計測結果です。計算量は同程度のはずですが、$N=2 \times 10^5$においてBinary Indexed TreeはSegment Treeのおよそ50%〜60%の計算時間になってますね。思ってたより定数倍の差があります。

add[s](N=$2\times 10^5$) get[s] (N=$2\times 10^5$) add[s] (N=$10^6$) get[s] (N=$10^6$)
Segment Tree 0.107s 0.192s 0.157s 0.253s
Binary Indexed Tree 0.056s 0.113s 0.109s 0.203s

Segment Tree vs Sparse Table vs Wavelet Matrix

続いてSegment TreeとSparse TableとWavelet Matrixについてmaxでのgetの比較をします。先程と同様に計算を10回程度廻した上での10回目の結果を使います。

seg = SegmentTree(A, max, 0)
spt = SparseTable(A, max)
wm  = WaveletMatrix(A)

start = time()
for l,r in get_queries:
    seg.get(l,r)
print(f"seg_get: {time()-start:0.3f}") 

start = time()
for l,r in get_queries:
    spt.get(l,r)
print(f"spt_get: {time()-start:0.3f}")

start = time()
for l,r in get_queries:
    wm.quantile(l,r,r-l-1)
print(f"wm__get: {time()-start:0.3f}")

以下は計測結果です。Sparse Tableはもともと計算量からして速いものですが、Segment Treeのおよそ15%の計算時間ととても高速に計算できています。Wavelet Matrixもその汎用性の割には十分な速度がでてるように見えます(今回の実装ですと$N=2\times 10^5$での構築に0.7s程度掛かってますが。

get[s] (N=$2\times 10^5$) get[s] (N=$10^6$)
Segment Tree 0.344s 0.362s
Sparse Table 0.044s 0.053s
Wavelet Matrix 0.413s 0.485s

結び

以上、簡単ではございましたが区間取得に関するデータ構造の速度比較を書かせて頂きました。この手のデータ構造には他にもモノイド演算を許容するDisjoint Sparse Tableなどがありますので、また時間がある時に追記などさせて頂ければと思います。

ここまで読んで頂きありがとうございました!

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1