1
0

Binary Indexed Treeの区間合計の定数倍改善手法の紹介

Last updated at Posted at 2024-09-03

はじめに

はじめましての人ははじめまして。alumiと申します。
少し前に書いた「Binary Indexed Treeの一点参照の定数倍改善手法の紹介」という記事で、一点取得の高速化に関する考察をしましたが、区間合計でも同様に定数倍高速化ができ、しかも前述の一点取得は区間合計手法の特殊な場合にまとめられることに気づいたので、解説と実装結果を紹介します。

結論

  • この実装をすると、prefix_sumの差による区間合計の実装に比べて定数倍改善できます。
  • 一点取得をrange_sum(i, i + 1)で実装しても、前回紹介した手法と同程度に高速です。
  • ほぼ同じ考え方で区間加算・区間合計BITについても区間合計・一点取得を定数倍改善できます。

実装と手法の解説

早速実装を見ましょう。
注意点として、一般的な実装と同様に逆元が定義されている必要があります(参考:BIT の定数倍高速化について - ecasdqina-cp

    def range_sum(self, l: int, r: int):
        """Range sum of [l, r) (0-indexed)"""
        res = 0
        while l < r:
            res += self.tree[r]
            r -= r & -r
        while r < l:
            res -= self.tree[l]
            l -= l & -l
        return res

一般的なprefix_sumの差による区間合計では、index0に近い側のblockが共通となり、同じ計算を繰り返す場合があります(下図)。ただ図を見ればわかる通り、ある区間$[l,r)$の区間合計に必要なblockは、その区間をすべて被覆するようなblockのうちlsbが最も小さいものの下にあるblockだけです。それより上にあるblockは計算に不要です(下図$[9,14)$ならtree index 12, 14が相当。具体的な証明は上リンクを参照ください)。

range_sum01.png

そこで$l$を直上で被覆するようなblockとなるまで、$r$からlsbを引きまくりその間の区間を足し、引いた時点での$r$を直上で被覆するようなblockとなるまで、$l$からlsbを引きまくりその間の区間を引くと、ほしい区間合計が得られます。

一点取得高速化手法との関係

前回の記事で紹介した一点取得高速化手法の実装は以下のようになります。

    def __getitem__(self, i: int):
        j = i + 1
        k = 1
        res = self.tree[j]
        while j & k == 0:
            res -= self.tree[j - k]
            k <<= 1
        return res

今回の手法において$l=i, r=i+1$とすると、まず最初のwhileは一度しか通らないことがわかります。なぜならlsbは必ず1以上なので、一回lsbを引かれると即座にl < rを満たさなくなるためです。
よって最初のwhileはres = self.tree[i + 1]で置き換えられます。

2つ目のwhileでは$l$が$r$と同じlsbを持つまでlsbを引いています(図を見ると$r$と同じ高さを目指しているのがわかりやすいです)。これは一点取得手法におけるwhile j & k == 0, k <<= 1 の実装と同値です($j = i+1 = r$としてwhile r & k == 0と見れば$r$のlsbが見つかるまで走査していることがわかる)。

以上から前回の記事で紹介した一点取得高速化手法は、今回の区間加算における特殊な場合として見ることができます。よって区間加算だけ今回の手法で実装してしまえば、一点取得については一般的なrange_sum(i, i + 1)のままでも前回手法と同程度に高速化できます。うれしい。

実装
class BinaryIndexedTree:
    __slots__ = ("n", "tree")

    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [0] * (self.n + 1)

    @classmethod
    def from_array(cls, arr) -> "BinaryIndexedTree":
        ins = cls(len(arr))
        ins.tree[1:] = arr
        for i in range(1, ins.n):
            j = i + (i & -i)
            if j <= ins.n:
                ins.tree[j] += ins.tree[i]
        return ins

    def prefix_sum(self, i: int):
        """Prefix sum of [0, i) (0-indexed)"""
        res = 0
        while i > 0:
            res += self.tree[i]
            i -= i & -i
        return res

    def add(self, i: int, x) -> None:
        """Add x at i (0-indexed)"""
        i += 1
        while i <= self.n:
            self.tree[i] += x
            i += i & -i

    def update(self, i: int, x):
        """Update i to x (0-indexed)"""
        self.add(i, x - self[i])

    def range_sum(self, l: int, r: int):
        """Range sum of [l, r) (0-indexed)"""
        res = 0
        while l < r:
            res += self.tree[r]
            r -= r & -r
        while r < l:
            res -= self.tree[l]
            l -= l & -l
        return res

    def __getitem__(self, i: int):
        return self.range_sum(i, i + 1)

    def __setitem__(self, i: int, val):
        self.update(i, val)

    def __iter__(self):
        yield from (self[i] for i in range(self.n))

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({list(self)})"

区間加算・区間合計BITへの適用

前回同様に区間加算・区間合計BITへ適用できないか考えてみます。
一般的な実装では、range_sumは次のような実装がされています。

    def range_sum(self, l: int, r: int):
        """Range sum of [l, r) (0-indexed)"""
        return self.prefix_sum(r) - self.prefix_sum(l)

また区間加算・区間合計BITにおけるprefix_sumは次のような実装がされています。

    def prefix_sum(self, i: int):
        """Prefix sum of [0, i) (0-indexed)"""
        return self._sum_partial(0, i) + self._sum_partial(1, i) * i

ここで_sum_partialは一点取得BITにおけるprefix_sumとほぼ同様の実装で、第一引数に応じて計算で用いるtreeを切り替えている以外に違いはありません。

    def _sum_partial(self, p: int, i: int):
        res = 0
        while i > 0:
            res += self.tree[p][i]
            i -= i & -i
        return res

よって式変形で上手いこと_sum_partial(p, r) - _sum_partial(p, l)の形をつくってやれば、この部分を高速化したrenge_sumに置き換えることで高速化できそうです。

区間加算・区間合計BITのrange_sumは次のように式変形できます。

\begin{eqnarray*}
\text{range_sum}(l,\ r) & = & \text{prefix_sum}(r)-\text{prefix_sum}(l) \\\\
& = & \text{sum_partial}(0,\ r) + \text{sum_partial}(1,\ r) * r \\
& - & \text{sum_partial}(0,\ l) - \text{sum_partial}(1,\ l) * l \\\\
& = & \text{sum_partial}(0,\ r) - \text{sum_partial}(0,\ l) \\
& + & \text{sum_partial}(1,\ r) * r - \text{sum_partial}(1,\ l) * l \\\\
& = & \text{sum_partial}(0,\ r) - \text{sum_partial}(0,\ l) \\
& + & \text{sum_partial}(1,\ r) * l - \text{sum_partial}(1,\ l) * l \\
& + & \text{sum_partial}(1,\ r) * (r - l) \\\\
& = & \underline{\text{sum_partial}(0,\ r) - \text{sum_partial}(0,\ l)} \\
& + & \big\lbrace\underline{\text{sum_partial}(1,\ r) - \text{sum_partial}(1,\ l)}\big\rbrace * l \\
& + & \text{sum_partial}(1,\ r) * (r - l)
\end{eqnarray*}

下線部でほしい形が作れていますね。ちなみに$\text{sum_partial}(1,\ r) - \text{sum_partial}(1,\ l)$の形を作るために$\text{sum_partial}(1,\ r) * r = \text{sum_partial}(1,\ r) * l + \text{sum_partial}(1,\ r) * (r-l)$と分離しています。
この部分を一点加算・区間合計BITと同様のrange_sum(下の_range_sum_partial)で置き換えてやればいいです。

    def _range_sum_partial(self, p: int, l: int, r: int):
        res = 0
        while l < r:
            res += self.tree[p][r]
            r -= r & -r
        while r < l:
            res -= self.tree[p][l]
            l -= l & -l
        return res

    def range_sum(self, l: int, r: int):
        """Range sum of [l, r) (0-indexed)"""
        res0 = self._range_sum_partial(0, l, r)
        res1 = self._range_sum_partial(1, l, r)
        return res0 + res1 * l + self._sum_partial(1, r) * (r - l)

このままでもいいのですが、res0, res1_range_sum_partialで同じ区間$[l,r)$について計算しているのでまとめられそうですね。
更に_range_sum_partialはどうせここでしか使わないのでrange_sum内にぶち込んでしまいましょう。注意点としては最後にかけるl, rres0, res1の計算で使うl, rを区別する必要があるので、tmp_l, tmp_rを用意しています。

    def range_sum(self, l: int, r: int):
        """Range sum of [l, r) (0-indexed)"""
        res0 = res1 = 0
        tmp_l, tmp_r = l, r
        while tmp_l < tmp_r:
            res0 += self.tree[0][tmp_r]
            res1 += self.tree[1][tmp_r]
            tmp_r -= tmp_r & -tmp_r
        while tmp_r < tmp_l:
            res0 -= self.tree[0][tmp_l]
            res1 -= self.tree[1][tmp_l]
            tmp_l -= tmp_l & -tmp_l
        return res0 + res1 * l + self._sum_partial(1, r) * (r - l)

ちなみに前回記事の区間加算・区間合計BITの一点取得高速化の提案は次のような感じでした。

    def __getitem__(self, i: int):
        j = i + 1
        k = 1
        res0 = self.tree[0][j]
        res1 = self.tree[1][j]
        while j & k == 0:
            res0 -= self.tree[0][j - k]
            res1 -= self.tree[1][j - k]
            k <<= 1
        return res0 + res1 * i + self._sum_partial(1, j)

よく似ていますね。実際$l,\ r = i,\ i+1$とすれば一点取得高速化手法との関係で述べた通り同値であることが示せます。区間加算・区間合計BITにおいても前回の記事の内容は今回の特殊系であることがわかりますね。

これで完了でもいいのですが、実は_sum_partial(1, r)の部分がまだまとめられます。
_sum_partial(1, r)では、$r$が$0$になるまでlsbを引きながらresに繰り返す足す計算をします。
range_sumres1を見ると、$r$が$l$を下回るまでlsbを引きながらres1に繰り返す足す計算をしています。つまり_sum_partial(1, r)の計算を途中まで行ってくれているということですね。

    def _sum_partial(self, p: int, i: int):
        res = 0
        # こっちはr <= 0まで
        while i > 0:
            res += self.tree[p][i]
            i -= i & -i
        return res

    def range_sum(self, l: int, r: int):
        """Range sum of [l, r) (0-indexed)"""
        res0 = res1 = 0
        tmp_l, tmp_r = l, r
        # こっちはr <= lまで
        while tmp_l < tmp_r:
            res0 += self.tree[0][tmp_r]
            res1 += self.tree[1][tmp_r] 
            tmp_r -= tmp_r & -tmp_r
		...

なのでrange_sumにおいて1つ目のwhileから出た時点のres1res2として取っておき、最後に残った$r$を$0$になるまでlsbを引きながらres2に足していけば、res2_sum_partial(1, r)と同じ値になります。
以上をまとめると最終形はこんな感じになります。あんまりきれいではないですね。

    def range_sum(self, l: int, r: int):
        """Range sum of [l, r) (0-indexed)"""
        res0 = res1 = 0
        tmp_l, tmp_r = l, r
        while tmp_l < tmp_r:
            res0 += self.tree[0][tmp_r]
            res1 += self.tree[1][tmp_r]
            tmp_r -= tmp_r & -tmp_r
        res2 = res1
        while tmp_r < tmp_l:
            res0 -= self.tree[0][tmp_l]
            res1 -= self.tree[1][tmp_l]
            tmp_l -= tmp_l & -tmp_l
        while tmp_r > 0:
            res2 += self.tree[1][tmp_r]
            tmp_r -= tmp_r & -tmp_r
        return res0 + res1 * l + res2 * (r - l)

ちなみにprefix_sumも同様のまとめ方ができます。つまり

    def prefix_sum(self, i: int):
        return self._sum_partial(0, i) + self._sum_partial(1, i) * i

としているものを

    def prefix_sum(self, i: int):
        res0 = res1 = 0
        tmp_i = i
        while tmp_i > 0:
            res0 += self.tree[0][tmp_i]
            res1 += self.tree[1][tmp_i]
            tmp_i -= tmp_i & -tmp_i
        return res0 + res1 * i

とまとめることができます。こうすると_sum_partialは他のメソッドで使われないのでまるっと消せます1。一応これについても速度チェックしておきましょうか。

速度チェック

一点加算・区間合計BIT|区間合計

ランダムな配列Aにより初期化したBITインスタンスを生成し、ランダムな区間合計クエリに対してかかる時間を比較します。

テストコード
class BinaryIndexedTree:
    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [0] * (self.n + 1)

    @classmethod
    def from_array(cls, arr) -> "BinaryIndexedTree":
        ins = cls(len(arr))
        ins.tree[1:] = arr
        for i in range(1, ins.n):
            j = i + (i & -i)
            if j <= ins.n:
                ins.tree[j] += ins.tree[i]
        return ins

    def prefix_sum(self, i: int):
        """Prefix sum of [0, i) (0-indexed)"""
        res = 0
        while i > 0:
            res += self.tree[i]
            i -= i & -i
        return res

    def range_sum_prefix_diff(self, l: int, r: int):
        return self.prefix_sum(r) - self.prefix_sum(l)

    def range_sum_two_while(self, l: int, r: int):
        res = 0
        while l < r:
            res += self.tree[r]
            r -= r & -r
        while r < l:
            res -= self.tree[l]
            l -= l & -l
        return res

import random
import timeit
import matplotlib.pyplot as plt

# テストパラメータ
N = 10**5  # 配列の長さ
Q = 10**5  # クエリの数
random.seed(42)  # Answer to the Ultimate Question of Life, the Universe, and Everything

# ランダムな要素の配列Aを作成し、BITを初期化
A = [random.randint(-(10**9), 10**9) for _ in range(N)]
bit = BinaryIndexedTree.from_array(A)

# ランダムなクエリを生成し、test_caseに格納
test_case = [(random.randrange(N), random.randrange(N)) for _ in range(Q)]
test_case = [(min(l, r), max(l, r)) for l, r in test_case]


# それぞれのメソッドをtest_caseで速度チェック
def test_range_sum_prefix_diff():
    for l, r in test_case:
        bit.range_sum_prefix_diff(l, r)


def test_range_sum_two_while():
    for l, r in test_case:
        bit.range_sum_two_while(l, r)

# loop回の平均を取る
loop = 100
time_prefix_diff = timeit.timeit(test_range_sum_prefix_diff, number=loop)
time_two_while = timeit.timeit(test_range_sum_two_while, number=loop)

# 結果を棒グラフで表示
methods = ["prefix_diff", "two_while"]
times = [time_prefix_diff / loop, time_two_while / loop]
y_positions = [0, 0.6]  # Y軸の位置を明示的に指定
plt.barh(y_positions, times, color=["blue", "green"], height=0.4)
plt.yticks(y_positions, methods)  # Y軸のラベルと位置を設定
plt.xlabel("Execution Time (sec)")
plt.title("Range Sum Comparison")
plt.show()

結果

result_1.png

わずかですがprefix_diff(一般的な実装)よりtwo_while(今回の実装)が速い結果が出ています。

一点加算・区間合計BIT|一点取得

ランダムな一点取得に対してかかる時間を比較します。

テストコード
class BinaryIndexedTree:
    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [0] * (self.n + 1)

    @classmethod
    def from_array(cls, arr) -> "BinaryIndexedTree":
        ins = cls(len(arr))
        ins.tree[1:] = arr
        for i in range(1, ins.n):
            j = i + (i & -i)
            if j <= ins.n:
                ins.tree[j] += ins.tree[i]
        return ins

    def prefix_sum(self, i: int):
        res = 0
        while i > 0:
            res += self.tree[i]
            i -= i & -i
        return res

    def range_sum(self, l: int, r: int):
        res = 0
        while l < r:
            res += self.tree[r]
            r -= r & -r
        while r < l:
            res -= self.tree[l]
            l -= l & -l
        return res

    def getitem_two_while(self, i: int):
        return self.range_sum(i, i + 1)

    def getitem_prefix_diff(self, i: int):
        return self.prefix_sum(i + 1) - self.prefix_sum(i)

    def getitem_one_while(self, i: int):
        j = i + 1
        k = 1
        res = self.tree[j]
        while j & k == 0:
            res -= self.tree[j - k]
            k <<= 1
        return res


import random
import timeit
import matplotlib.pyplot as plt

# テストパラメータ
N = 10**5  # 配列の長さ
Q = 10**5  # クエリの数
random.seed(42)  # Answer to the Ultimate Question of Life, the Universe, and Everything

# ランダムな要素の配列Aを作成し、BITを初期化
A = [random.randint(-(10**9), 10**9) for _ in range(N)]
bit = BinaryIndexedTree.from_array(A)

# ランダムなクエリを生成し、test_caseに格納
test_case = [random.randrange(N) for _ in range(Q)]


# それぞれのメソッドをtest_caseで速度チェック
def test_getitem_prefix_diff():
    for i in test_case:
        bit.getitem_prefix_diff(i)


def test_getitem_two_while():
    for i in test_case:
        bit.getitem_two_while(i)


def test_getitem_one_while():
    for i in test_case:
        bit.getitem_one_while(i)

# loop回の平均を取る
loop = 100
time_prefix_diff = timeit.timeit(test_getitem_prefix_diff, number=loop)
time_two_while = timeit.timeit(test_getitem_two_while, number=loop)
time_one_while = timeit.timeit(test_getitem_one_while, number=loop)


# 結果を棒グラフで表示
methods = ["prefix_diff", "two_while", "one_while"]
times = [time_prefix_diff / loop, time_two_while / loop, time_one_while / loop]
y_positions = [0, 0.6, 1.2]  # Y軸の位置を明示的に指定
plt.barh(y_positions, times, color=["blue", "green", "red"], height=0.4)
plt.yticks(y_positions, methods)  # Y軸のラベルと位置を設定
plt.xlabel("Execution Time (sec)")
plt.title("Get Item Comparison")
plt.show()

結果

result_2.png

prefix_diff(一般的な実装)よりもそれ以外が相当早くなっていますね。one_while(前回の実装)のほうがtwo_while(今回の実装)よりもわずかに早いです。実装を簡潔にする以上に高速化のニーズが高いならどっちも実装すると良いでしょうけど、この程度の差であればtwo_whileで実装をまとめたほうが嬉しみがありそうですね。

区間加算・区間合計BIT|区間合計

ランダムな配列Aにより初期化したBITインスタンスを生成し、ランダムな区間加算クエリを与えてから、ランダムな区間合計クエリに対してかかる時間を比較します。

テストコード
class BinaryIndexedTree:
    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [[0] * (self.n + 1), [0] * (self.n + 1)]

    @classmethod
    def from_array(cls, arr) -> "BinaryIndexedTree":
        ins = cls(len(arr))
        ins.tree[0][1:] = arr
        for i in range(1, ins.n):
            j = i + (i & -i)
            if j <= ins.n:
                ins.tree[0][j] += ins.tree[0][i]
        return ins

    def _sum_partial(self, p: int, i: int):
        res = 0
        while i > 0:
            res += self.tree[p][i]
            i -= i & -i
        return res

    def prefix_sum(self, i: int):
        """Prefix sum of [0, i) (0-indexed)"""
        return self._sum_partial(0, i) + self._sum_partial(1, i) * i

    def _add_partial(self, p: int, i: int, x) -> None:
        while i <= self.n:
            self.tree[p][i] += x
            i += i & -i

    def add_range(self, l: int, r: int, x) -> None:
        """Add x at [l, r) (0-indexed)"""
        l += 1
        self._add_partial(0, l, -x * (l - 1))
        self._add_partial(0, r, x * r)
        self._add_partial(1, l, x)
        self._add_partial(1, r, -x)

    def range_sum_prefix_diff(self, l: int, r: int):
        return self.prefix_sum(r) - self.prefix_sum(l)

    def _range_sum_partial(self, p: int, l: int, r: int):
        res = 0
        while l < r:
            res += self.tree[p][r]
            r -= r & -r
        while r < l:
            res -= self.tree[p][l]
            l -= l & -l
        return res

    def range_sum_use_rsp(self, l: int, r: int):
        res0 = self._range_sum_partial(0, l, r)
        res1 = self._range_sum_partial(1, l, r)
        return res0 + res1 * l + self._sum_partial(1, r) * (r - l)

    def range_sum_three_while(self, l: int, r: int):
        res0 = res1 = 0
        tmp_l, tmp_r = l, r
        while tmp_l < tmp_r:
            res0 += self.tree[0][tmp_r]
            res1 += self.tree[1][tmp_r]
            tmp_r -= tmp_r & -tmp_r
        res2 = res1
        while tmp_r < tmp_l:
            res0 -= self.tree[0][tmp_l]
            res1 -= self.tree[1][tmp_l]
            tmp_l -= tmp_l & -tmp_l
        while tmp_r > 0:
            res2 += self.tree[1][tmp_r]
            tmp_r -= tmp_r & -tmp_r
        return res0 + res1 * l + res2 * (r - l)


import random
import timeit
import matplotlib.pyplot as plt

# テストパラメータ
N = 10**5  # 配列の長さ
Q = 10**5  # クエリの数
random.seed(42)  # Answer to the Ultimate Question of Life, the Universe, and Everything

# ランダムな要素の配列Aを作成し、BITを初期化
A = [random.randint(-(10**9), 10**9) for _ in range(N)]
bit = BinaryIndexedTree.from_array(A)
# tree1側にも値を格納
for _ in range(1000):
    l = random.randrange(N + 1)
    r = random.randrange(N + 1)
    if l == r:
        continue
    l, r = min(l, r), max(l, r)
    x = random.randint(-10**9, 10**9)
    bit.add_range(l, r, x)


# ランダムなクエリを生成し、test_caseに格納
test_case = [(random.randrange(N), random.randrange(N)) for _ in range(Q)]
test_case = [(min(l, r), max(l, r)) for l, r in test_case]


# それぞれのメソッドをtest_caseで速度チェック
def test_range_sum_prefix_diff():
    for l, r in test_case:
        bit.range_sum_prefix_diff(l, r)


def test_range_sum_use_rsp():
    for l, r in test_case:
        bit.range_sum_use_rsp(l, r)


def test_range_sum_three_while():
    for l, r in test_case:
        bit.range_sum_three_while(l, r)


# loop回の平均を取る
loop = 100
time_prefix_diff = timeit.timeit(test_range_sum_prefix_diff, number=loop)
time_use_rsp = timeit.timeit(test_range_sum_use_rsp, number=loop)
time_three_while = timeit.timeit(test_range_sum_three_while, number=loop)


# 結果を棒グラフで表示
methods = ["prefix_diff", "use_rsp", "three_while"]
times = [time_prefix_diff / loop, time_use_rsp / loop, time_three_while / loop]
y_positions = [0, 0.6, 1.2]  # Y軸の位置を明示的に指定
plt.barh(y_positions, times, color=["red", "blue", "green"], height=0.4)
plt.yticks(y_positions, methods)  # Y軸のラベルと位置を設定
plt.xlabel("Execution Time (sec)")
plt.title("Range Sum Comparison")
plt.show()

結果

result_3.png

ちゃんとthree_while(今回の実装)が最速になってますね。ただ意外にもuse_rsp_range_sum_partialを用意して実装した実装)がprefix_diff(一般的な実装)に劣る結果になってますね。へんなの。

区間加算・区間合計BIT|一点取得

ランダムな一点取得に対してかかる時間を比較します。

テストコード
class BinaryIndexedTree:
    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [[0] * (self.n + 1), [0] * (self.n + 1)]

    @classmethod
    def from_array(cls, arr) -> "BinaryIndexedTree":
        ins = cls(len(arr))
        ins.tree[0][1:] = arr
        for i in range(1, ins.n):
            j = i + (i & -i)
            if j <= ins.n:
                ins.tree[0][j] += ins.tree[0][i]
        return ins

    def _sum_partial(self, p: int, i: int):
        res = 0
        while i > 0:
            res += self.tree[p][i]
            i -= i & -i
        return res

    def prefix_sum(self, i: int):
        return self._sum_partial(0, i) + self._sum_partial(1, i) * i

    def _add_partial(self, p: int, i: int, x) -> None:
        while i <= self.n:
            self.tree[p][i] += x
            i += i & -i

    def add_range(self, l: int, r: int, x) -> None:
        """Add x at [l, r) (0-indexed)"""
        l += 1
        self._add_partial(0, l, -x * (l - 1))
        self._add_partial(0, r, x * r)
        self._add_partial(1, l, x)
        self._add_partial(1, r, -x)

    def range_sum_prefix_diff(self, l: int, r: int):
        return self.prefix_sum(r) - self.prefix_sum(l)

    def range_sum_three_while(self, l: int, r: int):
        res0 = res1 = 0
        tmp_l, tmp_r = l, r
        while tmp_l < tmp_r:
            res0 += self.tree[0][tmp_r]
            res1 += self.tree[1][tmp_r]
            tmp_r -= tmp_r & -tmp_r
        res2 = res1
        while tmp_r < tmp_l:
            res0 -= self.tree[0][tmp_l]
            res1 -= self.tree[1][tmp_l]
            tmp_l -= tmp_l & -tmp_l
        while tmp_r > 0:
            res2 += self.tree[1][tmp_r]
            tmp_r -= tmp_r & -tmp_r
        return res0 + res1 * l + res2 * (r - l)

    def getitem_one_while(self, i: int):
        j = i + 1
        k = 1
        res0 = self.tree[0][j]
        res1 = self.tree[1][j]
        while j & k == 0:
            res0 -= self.tree[0][j - k]
            res1 -= self.tree[1][j - k]
            k <<= 1
        return res0 + res1 * i + self._sum_partial(1, j)

    def getitem_three_while(self, i):
        return self.range_sum_three_while(i, i + 1)

    def getitem_prefix_diff(self, i):
        return self.range_sum_prefix_diff(i, i + 1)


import random
import timeit
import matplotlib.pyplot as plt

# テストパラメータ
N = 10**5  # 配列の長さ
Q = 10**5  # クエリの数
random.seed(42)  # Answer to the Ultimate Question of Life, the Universe, and Everything

# ランダムな要素の配列Aを作成し、BITを初期化
A = [random.randint(-(10**9), 10**9) for _ in range(N)]
bit = BinaryIndexedTree.from_array(A)
# tree1側にも値を格納
for _ in range(1000):
    l = random.randrange(N + 1)
    r = random.randrange(N + 1)
    if l == r:
        continue
    l, r = min(l, r), max(l, r)
    x = random.randint(-10**9, 10**9)
    bit.add_range(l, r, x)


# ランダムなクエリを生成し、test_caseに格納
test_case = [random.randrange(N) for _ in range(Q)]


# それぞれのメソッドをtest_caseで速度チェック
def test_getitem_prefix_diff():
    for i in test_case:
        bit.getitem_prefix_diff(i)


def test_getitem_two_while():
    for i in test_case:
        bit.getitem_three_while(i)


def test_getitem_one_while():
    for i in test_case:
        bit.getitem_one_while(i)

# loop回の平均を取る
loop = 100
time_prefix_diff = timeit.timeit(test_getitem_prefix_diff, number=loop)
time_three_while = timeit.timeit(test_getitem_two_while, number=loop)
time_one_while = timeit.timeit(test_getitem_one_while, number=loop)


# 結果を棒グラフで表示
methods = ["prefix_diff", "three_while", "one_while"]
times = [time_prefix_diff / loop, time_three_while / loop, time_one_while / loop]
y_positions = [0, 0.6, 1.2]  # Y軸の位置を明示的に指定
plt.barh(y_positions, times, color=["blue", "green", "red"], height=0.4)
plt.yticks(y_positions, methods)  # Y軸のラベルと位置を設定
plt.xlabel("Execution Time (sec)")
plt.title("Get Item Comparison")
plt.show()

結果

result_4.png

ほぼ一点加算・区間合計BITのときの結果と同じですね。2倍以上早くなっています。強いて違いを言えばone_whilethree_whileとの差がほぼなくなっている点でしょうか。この結果からthree_whileで実装をまとめても問題なさそうなことがわかりますね。

区間加算・区間合計BIT|prefix_sum

ランダムな$i$に対するprefix_sumの計算にかかる時間を比較します。

テストコード
class BinaryIndexedTree:
    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [[0] * (self.n + 1), [0] * (self.n + 1)]

    @classmethod
    def from_array(cls, arr) -> "BinaryIndexedTree":
        ins = cls(len(arr))
        ins.tree[0][1:] = arr
        for i in range(1, ins.n):
            j = i + (i & -i)
            if j <= ins.n:
                ins.tree[0][j] += ins.tree[0][i]
        return ins

    def _sum_partial(self, p: int, i: int):
        res = 0
        while i > 0:
            res += self.tree[p][i]
            i -= i & -i
        return res

    def prefix_sum_use_sp(self, i: int):
        return self._sum_partial(0, i) + self._sum_partial(1, i) * i

    def prefix_sum_one_while(self, i: int):
        res0 = res1 = 0
        tmp_i = i
        while tmp_i > 0:
            res0 += self.tree[0][tmp_i]
            res1 += self.tree[1][tmp_i]
            tmp_i -= tmp_i & -tmp_i
        return res0 + res1 * i

    def _add_partial(self, p: int, i: int, x) -> None:
        while i <= self.n:
            self.tree[p][i] += x
            i += i & -i

    def add_range(self, l: int, r: int, x) -> None:
        """Add x at [l, r) (0-indexed)"""
        l += 1
        self._add_partial(0, l, -x * (l - 1))
        self._add_partial(0, r, x * r)
        self._add_partial(1, l, x)
        self._add_partial(1, r, -x)

import random
import timeit
import matplotlib.pyplot as plt

# テストパラメータ
N = 10**5  # 配列の長さ
Q = 10**5  # クエリの数
random.seed(42)  # Answer to the Ultimate Question of Life, the Universe, and Everything

# ランダムな要素の配列Aを作成し、BITを初期化
A = [random.randint(-(10**9), 10**9) for _ in range(N)]
bit = BinaryIndexedTree.from_array(A)
# tree1側にも値を格納
for _ in range(1000):
    l = random.randrange(N + 1)
    r = random.randrange(N + 1)
    if l == r:
        continue
    l, r = min(l, r), max(l, r)
    x = random.randint(-10**9, 10**9)
    bit.add_range(l, r, x)


# ランダムなクエリを生成し、test_caseに格納
test_case = [random.randrange(N) for _ in range(Q)]


# それぞれのメソッドをtest_caseで速度チェック
def test_prefix_sum_use_sp():
    for i in test_case:
        bit.prefix_sum_use_sp(i)


def test_prefix_sum_one_while():
    for i in test_case:
        bit.prefix_sum_one_while(i)


# loop回の平均を取る
loop = 100
time_prefix_sum_use_sp = timeit.timeit(test_prefix_sum_use_sp, number=loop)
time_prefix_sum_one_while = timeit.timeit(test_prefix_sum_one_while, number=loop)


# 結果を棒グラフで表示
methods = ["use_sp", "one_while"]
times = [time_prefix_sum_use_sp / loop, time_prefix_sum_one_while / loop]
y_positions = [0, 0.6]  # Y軸の位置を明示的に指定
plt.barh(y_positions, times, color=["blue", "green"], height=0.4)
plt.yticks(y_positions, methods)  # Y軸のラベルと位置を設定
plt.xlabel("Execution Time (sec)")
plt.title("Prefix Sum Comparison")
plt.show()

結果

result_5.png

sum_partialを使う実装(use_sp)に比べて結構改善されていますね。内部メソッドに依存しない形にできますし個人的には好みです。

おまけ:区間加算・区間合計BITの改善全部載せ実装
class BinaryIndexedTree:
    __slots__ = ("n", "tree")

    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [[0] * (self.n + 1), [0] * (self.n + 1)]

    @classmethod
    def from_array(cls, arr) -> "BinaryIndexedTree":
        ins = cls(len(arr))
        ins.tree[0][1:] = arr
        for i in range(1, ins.n):
            j = i + (i & -i)
            if j <= ins.n:
                ins.tree[0][j] += ins.tree[0][i]
        return ins

    def prefix_sum(self, i: int):
        res0 = res1 = 0
        tmp_i = i
        while tmp_i > 0:
            res0 += self.tree[0][tmp_i]
            res1 += self.tree[1][tmp_i]
            tmp_i -= tmp_i & -tmp_i
        return res0 + res1 * i

    def add_range(self, l: int, r: int, x) -> None:
        l += 1
        tmp_l, tmp_r = l, r
        while tmp_l <= self.n:
            self.tree[0][tmp_l] -= x * (l - 1)
            self.tree[1][tmp_l] += x
            tmp_l += tmp_l & -tmp_l
        while tmp_r <= self.n:
            self.tree[0][tmp_r] += x * r
            self.tree[1][tmp_r] -= x
            tmp_r += tmp_r & -tmp_r

    def add_point(self, i: int, x) -> None:
        """Add x at i (0-indexed)"""
        i += 1
        while i <= self.n:
            self.tree[0][i] += x
            i += i & -i

    def update_point(self, i: int, x) -> None:
        """Update i to x (0-indexed)"""
        self.add_point(i, x - self[i])

    def range_sum(self, l: int, r: int):
        """Range sum of [l, r) (0-indexed)"""
        res0 = res1 = 0
        tmp_l, tmp_r = l, r
        while tmp_l < tmp_r:
            res0 += self.tree[0][tmp_r]
            res1 += self.tree[1][tmp_r]
            tmp_r -= tmp_r & -tmp_r
        res2 = res1
        while tmp_r < tmp_l:
            res0 -= self.tree[0][tmp_l]
            res1 -= self.tree[1][tmp_l]
            tmp_l -= tmp_l & -tmp_l
        while tmp_r > 0:
            res2 += self.tree[1][tmp_r]
            tmp_r -= tmp_r & -tmp_r
        return res0 + res1 * l + res2 * (r - l)

    def __getitem__(self, i: int):
        return self.range_sum(i, i + 1)

    def __iter__(self):
        yield from (self[i] for i in range(self.n))

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({list(self)})"

結び

というわけでrange_sumの高速化の紹介と区間加算・区間合計BITへの適用の提案でした。実用上はどの実装でも困ることはないとは思いますが、こうした考えを通してデータ構造と仲良くなっておくと咄嗟の拡張にも対応できそうですね(願望)。
前回同様2Dへの拡張は未来の自分への宿題にします(この記事を読んでくださった方への遺題継承ともしたいです)。

この記事が参考になりましたら嬉しいです。よかったら左上のいいね押してね。

  1. この形式のprefix_sumは今回の手法のrange_sumの特殊系となるので、range_sum(0, i)と実装してもいい

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0