1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Binary Indexed Treeの一点参照の定数倍改善手法の紹介

Last updated at Posted at 2024-07-31

結論

はじめに

この記事は続編記事の「Binary Indexed Treeの区間合計の定数倍改善手法の紹介」の一部下位互換です。続編記事の内容を実装したほうが簡潔になります。

はじめましての人ははじめまして。alumiと申します。
趣味の競プロで使うBinary Indexed Tree(a.k.a. Fenwick Tree。以下BIT)のライブラリをいじっていたら、一般的なprefix_sumの差から導出する一点参照の実装を定数倍改善できた1ので、考察の過程と実装結果を紹介します。

改良前の一点参照の実装

以下対象とするBITは、一点加算クエリ・区間和導出クエリに対応し、0-indexedで実装しているものとします(区間加算クエリ・区間和クエリについては後述)。また実装はPythonで行います。

BITでは、区間$[0,\ i)$の和を導出するprefix_sumメソッドを実装し、prefix_sumの差prefix_sum(r) - prefix_sum(l)から区間和$[l,\ r)$を導出するrange_sum(l, r)を実装する形式が一般的です。

range_sumが実装されていれば、ある一点$i$を参照したい場合はrange_sum(i, i + 1)とすれば導出できますし、実際これは$O(\log{N})$で動作するので十分高速です。

bit説明_prefix_sum.png

例えば上図でTree indexが6となっている箇所を一点参照したい場合、1~6までを被覆するブロックの和(濃い紫色)から、1~5までを被覆するブロック(青色)を引けば、1つ分のブロックの値が導出できます。

実装例(改良前)
class BinaryIndexedTree:
    __slot__ = ("n", "tree")

    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [0] * (self.n + 1)

    @classmethod
    def from_array(cls, arr) -> "BinaryIndexedTree":
        ins = cls(len(arr))
        ins.tree[1:] = arr
        for i in range(1, ins.n):
            j = i + (i & -i)
            if j <= ins.n:
                ins.tree[j] += ins.tree[i]
        return ins

    def prefix_sum(self, i: int):
        """Prefix sum of [0, i) (0-indexed)"""
        res = 0
        while i > 0:
            res += self.tree[i]
            i -= i & -i
        return res

    def add(self, i: int, x) -> None:
        """Add x at i (0-indexed)"""
        i += 1
        while i <= self.n:
            self.tree[i] += x
            i += i & -i

    def range_sum(self, l: int, r: int):
        """Range sum of [l, r) (0-indexed)"""
        return self.prefix_sum(r) - self.prefix_sum(l)

    def __getitem__(self, i: int):
        return self.range_sum(i, i + 1)

改良の着想

ただ図を眺めるとわかりますが、6だけがほしい場合は6の担当する区間和ブロックから隣の5だけを引けば求めることができるので、range_sum実装では少し無駄な計算をしていることがわかります。

bit説明_奇数参照.png

上図を一点参照を改善したい気持ちで見ると、図でTree indexとなっている部分が奇数の値は計算するまでもなく直接Treeから参照すればいいことがわかります(それ以外についてはrange_sumで求める)。奇数は2進数で表したとき一番左のbitが立っているので、奇数かどうかはif j & 1 == 1:で判定できます。

    def __getitem__(self, i: int):
+       j = i + 1
+       if j & 1:
+           return self.tree[i + 1]
        return self.range_sum(i, i + 1)

これだけで全体の50%を$O(\log{N})$から$O(1)$に改善できるので、実用上はこれでも十分だと思います。

もう少し欲張って改善しようとすると、次に導出が簡単そうなのは上図でlsb(最下位ビット)が**1*となっている箇所(下から2段目、2, 6, 10)は、自分のすぐ左隣りのブロックを引くだけでひとつ分の値が計算できそうです。

bit説明_lsb2.png

実際上図のようにlsbが**1*の段にある6は、自分が担当するブロックからすぐ左隣りの5を引けば良さそうです。
lsbが**1*かどうかは、if j & 1 == 1:がFalseとなった後ならif j & 2 == 1かどうかで判定できるので、Trueなら自分の値から一つ前の値を引いた値を返せばいいです。

    def __getitem__(self, i: int):
        j = i + 1
        if j & 1: # lsb = 1
            return self.tree[j]
+       if j & 2: # lsb = 2
+           return self.tree[j] - self.tree[j - 1]
        return self.range_sum(i, i + 1)

同様にlsbが*1**となっている箇所を一点参照したい場合は、自分自身から一つ前と二つ前を引けばいいことが図を眺めるとわかります。
lsbが*1**かどうかは、j & 4 == 1かどうかで判定できます。

bit説明_lsb3.png

    def __getitem__(self, i: int):
        j = i + 1
        if j & 1: # lsb = 1
            return self.tree[j]
        if j & 2: # lsb = 2
            return self.tree[j] - self.tree[j - 1]
+       if j & 4: # lsb = 4
+           return self.tree[j] - self.tree[j - 1] - self.tree[j - 2]
        return self.range_sum(i, i + 1)

ここまでで全体の87.5%(50%+25%+12.5%)を$O(1)$で導出することができるようになりました。
これまでを簡単にまとめると、

  • lsbが***1self.tree[j]
  • lsbが**1*self.tree[j] - self.tree[j - 1]
  • lsbが*1**self.tree[j] - self.tree[j - 1] - self.tree[j - 2]

となっており、引く値のindexが一つずつ下がっていっています。
ではlsbが1***では自分から三つ目までを引けばいいのでしょうか。実は一つずつの法則はここまでで、これ以降は引く値が2べきで増えていきます(実はこれまでもそう)。

bit説明_lsb4.png

上図からわかる通り、自分が担当するブロックからそれより前の値を順番に引いていくのは同じですが、三つ目の値は自分のindexから4つ前になります(上図のように8なら4)。
lsbが1***かどうかは、j & 8 == 1かどうかで判定できます。

    def __getitem__(self, i: int):
        j = i + 1
        if j & 1: # lsb = 1
            return self.tree[j]
        if j & 2: # lsb = 2
            return self.tree[j] - self.tree[j - 1]
        if j & 4: # lsb = 4
            return self.tree[j] - self.tree[j - 1] - self.tree[j - 2]
+       if j & 8: # lsb = 8
+           return self.tree[j] - self.tree[j - 1] - self.tree[j - 2] - self.tree[j - 4]
        return self.range_sum(i, i + 1)

一般化

ここまでくればなんとなく法則性が見えてきて、任意の$j$に対して一般化できそうです。つまり参照したいTree indexを$j \ ( = i+1$)とすると、自分が担当するブロックのlsbと引く値のlsbが同じ高さになるまで、自分から2べきだけ離れているブロックを引き続けるという感じで一般化できそうです。
これは実装を見ていただいたほうがわかりやすいと思います。

    def __getitem__(self, i: int):
        j = i + 1
        k = 1
        res = self.tree[j]
        while j & k == 0:
            res -= self.tree[j - k]
            k <<= 1
        return res

res = self.tree[j]として自分が担当するブロックで初期化し、j & k == 0の間、つまり自分が担当するブロックのlsbをkが超えるまでself.tree[j - k]をresから引きまくる感じです。

速度テスト

実際どの程度早くなったのかテストしてみましょう。
テストはランダムな配列AでBITインスタンスを初期化し、ランダムな$i$を参照した結果を比較しています。

テストコード
class BinaryIndexedTree:
    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [0] * (self.n + 1)

    @classmethod
    def from_array(cls, arr) -> "BinaryIndexedTree":
        ins = cls(len(arr))
        ins.tree[1:] = arr
        for i in range(1, ins.n):
            j = i + (i & -i)
            if j <= ins.n:
                ins.tree[j] += ins.tree[i]
        return ins

    def prefix_sum(self, i: int):
        """Prefix sum of [0, i) (0-indexed)"""
        res = 0
        while i > 0:
            res += self.tree[i]
            i -= i & -i
        return res

    def range_sum(self, l: int, r: int):
        """Range sum of [l, r) (0-indexed)"""
        return self.prefix_sum(r) - self.prefix_sum(l)

    def getitem_prefix_diff(self, i: int):
        return self.range_sum(i, i + 1)

    def getitem_one_lsb(self, i: int):
        j = i + 1
        if j & 1:  # lsb = 1
            return self.tree[j]
        return self.range_sum(i, i + 1)

    def getitem_four_lsb(self, i: int):
        j = i + 1
        if j & 1:  # lsb = 1
            return self.tree[j]
        if j & 2:  # lsb = 2
            return self.tree[j] - self.tree[j - 1]
        if j & 4:  # lsb = 4
            return self.tree[j] - self.tree[j - 1] - self.tree[j - 2]
        if j & 8:  # lsb = 8
            return self.tree[j] - self.tree[j - 1] - self.tree[j - 2] - self.tree[j - 4]
        return self.range_sum(i, i + 1)

    def getitem_generalize(self, i: int):
        j = i + 1
        res = self.tree[j]
        k = 1
        while j & k == 0:
            res -= self.tree[j - k]
            k <<= 1
        return res


import random
import timeit
import matplotlib.pyplot as plt

# テストパラメータ
N = 10**5  # 配列の長さ
Q = 10**5  # クエリの数
random.seed(42)  # Answer to the Ultimate Question of Life, the Universe, and Everything

# ランダムな要素の配列Aを作成し、BITを初期化
A = [random.randint(-(10**9), 10**9) for _ in range(N)]
bit = BinaryIndexedTree.from_array(A)

# ランダムなクエリを生成し、test_caseに格納
test_case = [(random.randrange(N)) for _ in range(Q)]


# それぞれのメソッドをtest_caseで速度チェック
def test_getitem_prefix_diff():
    for i in test_case:
        bit.getitem_prefix_diff(i)


def test_getitem_one_lsb():
    for i in test_case:
        bit.getitem_one_lsb(i)


def test_getitem_four_lsb():
    for i in test_case:
        bit.getitem_four_lsb(i)


def test_getitem_generalize():
    for i in test_case:
        bit.getitem_generalize(i)


# loop回の平均を取る
loop = 100
time_getitem_prefix_diff = timeit.timeit(test_getitem_prefix_diff, number=loop)
time_getitem_one_lsb = timeit.timeit(test_getitem_one_lsb, number=loop)
time_getitem_four_lsb = timeit.timeit(test_getitem_four_lsb, number=loop)
time_getitem_generalize = timeit.timeit(test_getitem_generalize, number=loop)

# 結果を棒グラフで表示
methods = ["prefix_diff", "one_lsb", "four_lsb", "generalize"]
times = [
    time_getitem_prefix_diff / loop,
    time_getitem_one_lsb / loop,
    time_getitem_four_lsb / loop,
    time_getitem_generalize / loop,
]

y_positions = [0, 0.6, 1.2, 1.8]  # Y軸の位置を明示的に指定
plt.barh(y_positions, times, color=["blue", "green", "red", "purple"], height=0.4)
plt.yticks(y_positions, methods)  # Y軸のラベルと位置を設定
plt.xlabel("Execution Time (sec)")
plt.title("Get item Comparison")
plt.show()

結果

result1.png

prefis diff(標準的な実装)に比べて計算通りに向上しています(one_lsbは50%の改善、four_lsbは80%程度の改善)。実行環境や要素数によってはlsb4まで一般化を若干上回る場合もありますが、平均的には時間がprefis diff > one_lsb > four_lsb > generalizeとなっていました。うれしい。

なおここまでの実装はtreeのみを使って一点参照することを目的としていましたが、メモリが許すならoriginal_arrを持っておき一点加算の度に更新すれば、一点参照は$O(1)$でできます。
なので今回の実装がより生きるのは次に説明する区間加算に対応したBITでの一点参照のほうかもしれないです。

応用

区間加算・一点参照BITへの適用

一点加算・区間和BITに対してimos法を適用すると、区間加算と一点参照をそれぞれ$O(\log{N})$で行えるBITが実装できます。

区間加算・一点参照BITの実装例(改良前)
class BinaryIndexedTree:
    __slot__ = ("n", "tree")

    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [0] * (self.n + 1)

    @classmethod
    def from_array(cls, arr) -> "BinaryIndexedTree":
        ins = cls(len(arr))
        ins.tree[1:] = arr
        for i, a in enumerate(arr, 1):
            if i < ins.n:
                ins.tree[i + 1] -= a
            j = i + (i & -i)
            if j <= ins.n:
                ins.tree[j] += ins.tree[i]
        return ins

    def _prefix_sum(self, i: int):
        """Prefix sum of [0, i) (0-indexed)"""
        res = 0
        while i > 0:
            res += self.tree[i]
            i -= i & -i
        return res

    def _add(self, i: int, x) -> None:
        """Add x at i (0-indexed)"""
        i += 1
        while i <= self.n:
            self.tree[i] += x
            i += i & -i

    def add_range(self, l: int, r: int, x) -> None:
        """Add x at [l, r) (0-indexed)"""
        self._add(l, x)
        self._add(r, -x)

    def __getitem__(self, i: int):
        return self._prefix_sum(i + 1)

    def __iter__(self):
        yield from (self[i] for i in range(self.n))

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({list(self)})"

一点参照はprefix_sumを一回実行するだけなので、今回の手法での一点参照の効率化は図れません。
ではどこで使えるかというと、配列要素のIterationで今回の手法が使えます(上の実装例の__iter__メソッド)。
naiveな実装の__iter__では、毎回prefix_sumを呼び出して各要素を求めますが、自分より前の区間について繰り返し加算しており少し無駄な計算をしています。

改善版の__iter__は最初res = 0から始めて、隣接するTree上の一点要素を一つ足して次の要素を順番に得ることで実装できます。ここの一点要素取得に今回の手法を用いることで、毎回prefix_sumをする方法より効率的に必要な要素を加算でき、定数倍改善できます。

    def _get_point(self, i: int):
        j = i + 1
        k = 1
        res = self.tree[j]
        while j & k == 0:
            res -= self.tree[j - k]
            k <<= 1
        return res

    def __iter__(self):
        res = 0
        for i in range(self.n):
            res += self._get_point(i)
            yield res

    # __iter__の代わりにtolistを実装してもいい。やってることは同じ
    def tolist(self):
        res = [self.tree[1]]
        for i in range(1, self.n):
            res.append(res[-1] + self._get_point(i))
        return res

_get_pointで今回の手法により一点参照を行い、呼び出しのたびに足していくような実装です。

区間加算・区間和BITへの適用

定数加算をtree[0]で、線形に増加する加算をtree[1]という2本のBITで管理することで、区間加算・区間和BITを実装できることが知られています。

この区間加算・区間和BITにおいても一点参照の改良はできるのでしょうか?実はできます。

区間加算・区間和BITの実装例(改良前)
class BinaryIndexedTree:
    __slot__ = ("n", "tree")

    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [[0] * (self.n + 1), [0] * (self.n + 1)]

    @classmethod
    def from_array(cls, arr) -> "BinaryIndexedTree":
        ins = cls(len(arr))
        ins.tree[0][1:] = arr
        for i in range(1, ins.n):
            j = i + (i & -i)
            if j <= ins.n:
                ins.tree[0][j] += ins.tree[0][i]
        return ins


    def _sum_partial(self, p: int, i: int):
        res = 0
        while i > 0:
            res += self.tree[p][i]
            i -= i & -i
        return res

    def prefix_sum(self, i: int):
        """Prefix sum of [0, i) (0-indexed)"""
        return self._sum_partial(0, i) + self._sum_partial(1, i) * i

    def _add_partial(self, p: int, i: int, x) -> None:
        while i <= self.n:
            self.tree[p][i] += x
            i += i & -i

    def add_range(self, l: int, r: int, x) -> None:
        """Add x at [l, r) (0-indexed)"""
        l += 1
        self._add_partial(0, l, -x * (l - 1))
        self._add_partial(0, r, x * r)
        self._add_partial(1, l, x)
        self._add_partial(1, r, -x)

    def add_point(self, i: int, x) -> None:
        """Add x at i (0-indexed)"""
        self._add_partial(0, i + 1, x)

    def update_point(self, i: int, x) -> None:
        """Update i to x (0-indexed)"""
        self._add_partial(0, i + 1, x - self[i])

    def range_sum(self, l: int, r: int):
        """Range sum of [l, r) (0-indexed)"""
        return self.prefix_sum(r) - self.prefix_sum(l)

    def __getitem__(self, i: int):
        return self.range_sum(i, i + 1)

    def __iter__(self):
        yield from (self[i] for i in range(self.n))

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({list(self)})"

詳細は省きますが、区間加算・区間和BITのprefix_sum(i)は以下のように実装されます。

    def prefix_sum(self, i: int):
        """Prefix sum of [0, i) (0-indexed)"""
        return self._sum_partial(0, i) + self._sum_partial(1, i) * i

ここで_sum_partial(p, i)は一点加算BITのprefix_sum(i)と同じ動作をし、pで木の種類を指定しています。
range_sum(i, i+1)を式変形すると、

\begin{eqnarray*}
\text{range_sum}(i,\ i+1) & = & \text{prefix_sum}(i+1)-\text{prefix_sum}(i) \\\\
& = & \text{sum_partial}(0,\ i+1) + \text{sum_partial}(1,\ i+1) * (i+1) \\
& - & \text{sum_partial}(0,\ i) - \text{sum_partial}(1,\ i) * i \\\\
& = & \underline{\text{sum_partial}(0,\ i+1) - \text{sum_partial}(0,\ i)} \\
& + & \big\lbrace\underline{\text{sum_partial}(1,\ i+1) - \text{sum_partial}(1,\ i)}\big\rbrace * i \\
& + & \text{sum_partial}(1,\ i+1)
\end{eqnarray*}

と式変形できます。ここで変形後の下線部は、_sum_partial(p, i)が一点加算BITのprefix_sum(i)と同じ動作をすることを考えれば、一点加算BITの一点参照と同じように高速化できます。
よって三段目の$\text{sum_partial}(1,\ i+1)$だけ計算すればよく、元々4回のsum_partialの計算が3回に改善されることになります。わーい。

    def __getitem__(self, i: int):
        j = i + 1
        k = 1
        res0 = self.tree[0][j]
        res1 = self.tree[1][j]
        while j & k == 0:
            res0 -= self.tree[0][j - k]
            res1 -= self.tree[1][j - k]
            k <<= 1
        return res0 + res1 * i + self._sum_partial(1, j)

この実装ではres0は一項目の$\text{sum_partial}(0,\ i+1) - \text{sum_partial}(0,\ i)$を、res1は二項目の$\text{sum_partial}(1,\ i+1) - \text{sum_partial}(1,\ i)$を担っています。それぞれを一点加算BITの時と同様に計算することができます。

まとめと今後の課題

以上のようにrange_sumによる一点参照の実装より高速化することができました。これより早くできる実装はないと思います(あったら教えてください。すごく驚きます)。
着想から実装までを順に追っていったので長文になってしまいましたが、実装を理解していただけたならうれしいです。

ただまだ以下の課題は残っているので、解決できた方はぜひ教えてください。

2次元BITへの適用

1次元のBITを2本持つことで2次元上の一点加算・区間和クエリに対応するBITや、1次元のBITを4本持つことで2次元上の区間加算・区間和クエリに対応するBITが知られています。
これらの2次元BITにも今回の手法が適用できそうな気配がありますが、自分の頭では上手く適用することができませんでした。誰か代わりにお願いします。

References

  1. 車輪の再発明ですが、考察の過程がほかの人の理解の助けになればいいな~というモチベーションで書いてます。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?