結論
-
この実装をすると、
prefix_sum
の差による一点参照の実装に比べて定数倍改善できます。 - 区間加算・一点参照BITのIterationを定数倍改善できます。
- 区間加算・区間和BITについても一点参照を定数倍改善できます。
はじめに
この記事は続編記事の「Binary Indexed Treeの区間合計の定数倍改善手法の紹介」の一部下位互換です。続編記事の内容を実装したほうが簡潔になります。
はじめましての人ははじめまして。alumiと申します。
趣味の競プロで使うBinary Indexed Tree(a.k.a. Fenwick Tree。以下BIT)のライブラリをいじっていたら、一般的なprefix_sum
の差から導出する一点参照の実装を定数倍改善できた1ので、考察の過程と実装結果を紹介します。
改良前の一点参照の実装
以下対象とするBITは、一点加算クエリ・区間和導出クエリに対応し、0-indexedで実装しているものとします(区間加算クエリ・区間和クエリについては後述)。また実装はPythonで行います。
BITでは、区間$[0,\ i)$の和を導出するprefix_sum
メソッドを実装し、prefix_sum
の差prefix_sum(r) - prefix_sum(l)
から区間和$[l,\ r)$を導出するrange_sum(l, r)
を実装する形式が一般的です。
range_sum
が実装されていれば、ある一点$i$を参照したい場合はrange_sum(i, i + 1)
とすれば導出できますし、実際これは$O(\log{N})$で動作するので十分高速です。
例えば上図でTree indexが6となっている箇所を一点参照したい場合、1~6までを被覆するブロックの和(濃い紫色)から、1~5までを被覆するブロック(青色)を引けば、1つ分のブロックの値が導出できます。
実装例(改良前)
class BinaryIndexedTree:
__slot__ = ("n", "tree")
def __init__(self, n: int) -> None:
self.n = n
self.tree = [0] * (self.n + 1)
@classmethod
def from_array(cls, arr) -> "BinaryIndexedTree":
ins = cls(len(arr))
ins.tree[1:] = arr
for i in range(1, ins.n):
j = i + (i & -i)
if j <= ins.n:
ins.tree[j] += ins.tree[i]
return ins
def prefix_sum(self, i: int):
"""Prefix sum of [0, i) (0-indexed)"""
res = 0
while i > 0:
res += self.tree[i]
i -= i & -i
return res
def add(self, i: int, x) -> None:
"""Add x at i (0-indexed)"""
i += 1
while i <= self.n:
self.tree[i] += x
i += i & -i
def range_sum(self, l: int, r: int):
"""Range sum of [l, r) (0-indexed)"""
return self.prefix_sum(r) - self.prefix_sum(l)
def __getitem__(self, i: int):
return self.range_sum(i, i + 1)
改良の着想
ただ図を眺めるとわかりますが、6だけがほしい場合は6の担当する区間和ブロックから隣の5だけを引けば求めることができるので、range_sum
実装では少し無駄な計算をしていることがわかります。
上図を一点参照を改善したい気持ちで見ると、図でTree indexとなっている部分が奇数の値は計算するまでもなく直接Treeから参照すればいいことがわかります(それ以外についてはrange_sum
で求める)。奇数は2進数で表したとき一番左のbitが立っているので、奇数かどうかはif j & 1 == 1:
で判定できます。
def __getitem__(self, i: int):
+ j = i + 1
+ if j & 1:
+ return self.tree[i + 1]
return self.range_sum(i, i + 1)
これだけで全体の50%を$O(\log{N})$から$O(1)$に改善できるので、実用上はこれでも十分だと思います。
もう少し欲張って改善しようとすると、次に導出が簡単そうなのは上図でlsb(最下位ビット)が**1*
となっている箇所(下から2段目、2, 6, 10)は、自分のすぐ左隣りのブロックを引くだけでひとつ分の値が計算できそうです。
実際上図のようにlsbが**1*
の段にある6は、自分が担当するブロックからすぐ左隣りの5を引けば良さそうです。
lsbが**1*
かどうかは、if j & 1 == 1:
がFalseとなった後ならif j & 2 == 1
かどうかで判定できるので、Trueなら自分の値から一つ前の値を引いた値を返せばいいです。
def __getitem__(self, i: int):
j = i + 1
if j & 1: # lsb = 1
return self.tree[j]
+ if j & 2: # lsb = 2
+ return self.tree[j] - self.tree[j - 1]
return self.range_sum(i, i + 1)
同様にlsbが*1**
となっている箇所を一点参照したい場合は、自分自身から一つ前と二つ前を引けばいいことが図を眺めるとわかります。
lsbが*1**
かどうかは、j & 4 == 1
かどうかで判定できます。
def __getitem__(self, i: int):
j = i + 1
if j & 1: # lsb = 1
return self.tree[j]
if j & 2: # lsb = 2
return self.tree[j] - self.tree[j - 1]
+ if j & 4: # lsb = 4
+ return self.tree[j] - self.tree[j - 1] - self.tree[j - 2]
return self.range_sum(i, i + 1)
ここまでで全体の87.5%(50%+25%+12.5%)を$O(1)$で導出することができるようになりました。
これまでを簡単にまとめると、
- lsbが
***1
→self.tree[j]
- lsbが
**1*
→self.tree[j] - self.tree[j - 1]
- lsbが
*1**
→self.tree[j] - self.tree[j - 1] - self.tree[j - 2]
となっており、引く値のindexが一つずつ下がっていっています。
ではlsbが1***
では自分から三つ目までを引けばいいのでしょうか。実は一つずつの法則はここまでで、これ以降は引く値が2べきで増えていきます(実はこれまでもそう)。
上図からわかる通り、自分が担当するブロックからそれより前の値を順番に引いていくのは同じですが、三つ目の値は自分のindexから4つ前になります(上図のように8なら4)。
lsbが1***
かどうかは、j & 8 == 1
かどうかで判定できます。
def __getitem__(self, i: int):
j = i + 1
if j & 1: # lsb = 1
return self.tree[j]
if j & 2: # lsb = 2
return self.tree[j] - self.tree[j - 1]
if j & 4: # lsb = 4
return self.tree[j] - self.tree[j - 1] - self.tree[j - 2]
+ if j & 8: # lsb = 8
+ return self.tree[j] - self.tree[j - 1] - self.tree[j - 2] - self.tree[j - 4]
return self.range_sum(i, i + 1)
一般化
ここまでくればなんとなく法則性が見えてきて、任意の$j$に対して一般化できそうです。つまり参照したいTree indexを$j \ ( = i+1$)とすると、自分が担当するブロックのlsbと引く値のlsbが同じ高さになるまで、自分から2べきだけ離れているブロックを引き続けるという感じで一般化できそうです。
これは実装を見ていただいたほうがわかりやすいと思います。
def __getitem__(self, i: int):
j = i + 1
k = 1
res = self.tree[j]
while j & k == 0:
res -= self.tree[j - k]
k <<= 1
return res
res = self.tree[j]
として自分が担当するブロックで初期化し、j & k == 0
の間、つまり自分が担当するブロックのlsbをk
が超えるまでself.tree[j - k]
をresから引きまくる感じです。
速度テスト
実際どの程度早くなったのかテストしてみましょう。
テストはランダムな配列AでBITインスタンスを初期化し、ランダムな$i$を参照した結果を比較しています。
テストコード
class BinaryIndexedTree:
def __init__(self, n: int) -> None:
self.n = n
self.tree = [0] * (self.n + 1)
@classmethod
def from_array(cls, arr) -> "BinaryIndexedTree":
ins = cls(len(arr))
ins.tree[1:] = arr
for i in range(1, ins.n):
j = i + (i & -i)
if j <= ins.n:
ins.tree[j] += ins.tree[i]
return ins
def prefix_sum(self, i: int):
"""Prefix sum of [0, i) (0-indexed)"""
res = 0
while i > 0:
res += self.tree[i]
i -= i & -i
return res
def range_sum(self, l: int, r: int):
"""Range sum of [l, r) (0-indexed)"""
return self.prefix_sum(r) - self.prefix_sum(l)
def getitem_prefix_diff(self, i: int):
return self.range_sum(i, i + 1)
def getitem_one_lsb(self, i: int):
j = i + 1
if j & 1: # lsb = 1
return self.tree[j]
return self.range_sum(i, i + 1)
def getitem_four_lsb(self, i: int):
j = i + 1
if j & 1: # lsb = 1
return self.tree[j]
if j & 2: # lsb = 2
return self.tree[j] - self.tree[j - 1]
if j & 4: # lsb = 4
return self.tree[j] - self.tree[j - 1] - self.tree[j - 2]
if j & 8: # lsb = 8
return self.tree[j] - self.tree[j - 1] - self.tree[j - 2] - self.tree[j - 4]
return self.range_sum(i, i + 1)
def getitem_generalize(self, i: int):
j = i + 1
res = self.tree[j]
k = 1
while j & k == 0:
res -= self.tree[j - k]
k <<= 1
return res
import random
import timeit
import matplotlib.pyplot as plt
# テストパラメータ
N = 10**5 # 配列の長さ
Q = 10**5 # クエリの数
random.seed(42) # Answer to the Ultimate Question of Life, the Universe, and Everything
# ランダムな要素の配列Aを作成し、BITを初期化
A = [random.randint(-(10**9), 10**9) for _ in range(N)]
bit = BinaryIndexedTree.from_array(A)
# ランダムなクエリを生成し、test_caseに格納
test_case = [(random.randrange(N)) for _ in range(Q)]
# それぞれのメソッドをtest_caseで速度チェック
def test_getitem_prefix_diff():
for i in test_case:
bit.getitem_prefix_diff(i)
def test_getitem_one_lsb():
for i in test_case:
bit.getitem_one_lsb(i)
def test_getitem_four_lsb():
for i in test_case:
bit.getitem_four_lsb(i)
def test_getitem_generalize():
for i in test_case:
bit.getitem_generalize(i)
# loop回の平均を取る
loop = 100
time_getitem_prefix_diff = timeit.timeit(test_getitem_prefix_diff, number=loop)
time_getitem_one_lsb = timeit.timeit(test_getitem_one_lsb, number=loop)
time_getitem_four_lsb = timeit.timeit(test_getitem_four_lsb, number=loop)
time_getitem_generalize = timeit.timeit(test_getitem_generalize, number=loop)
# 結果を棒グラフで表示
methods = ["prefix_diff", "one_lsb", "four_lsb", "generalize"]
times = [
time_getitem_prefix_diff / loop,
time_getitem_one_lsb / loop,
time_getitem_four_lsb / loop,
time_getitem_generalize / loop,
]
y_positions = [0, 0.6, 1.2, 1.8] # Y軸の位置を明示的に指定
plt.barh(y_positions, times, color=["blue", "green", "red", "purple"], height=0.4)
plt.yticks(y_positions, methods) # Y軸のラベルと位置を設定
plt.xlabel("Execution Time (sec)")
plt.title("Get item Comparison")
plt.show()
結果
prefis diff
(標準的な実装)に比べて計算通りに向上しています(one_lsb
は50%の改善、four_lsb
は80%程度の改善)。実行環境や要素数によってはlsb4まで
が一般化
を若干上回る場合もありますが、平均的には時間がprefis diff
> one_lsb
> four_lsb
> generalize
となっていました。うれしい。
なおここまでの実装はtreeのみを使って一点参照することを目的としていましたが、メモリが許すならoriginal_arr
を持っておき一点加算の度に更新すれば、一点参照は$O(1)$でできます。
なので今回の実装がより生きるのは次に説明する区間加算に対応したBITでの一点参照のほうかもしれないです。
応用
区間加算・一点参照BITへの適用
一点加算・区間和BITに対してimos法を適用すると、区間加算と一点参照をそれぞれ$O(\log{N})$で行えるBITが実装できます。
区間加算・一点参照BITの実装例(改良前)
class BinaryIndexedTree:
__slot__ = ("n", "tree")
def __init__(self, n: int) -> None:
self.n = n
self.tree = [0] * (self.n + 1)
@classmethod
def from_array(cls, arr) -> "BinaryIndexedTree":
ins = cls(len(arr))
ins.tree[1:] = arr
for i, a in enumerate(arr, 1):
if i < ins.n:
ins.tree[i + 1] -= a
j = i + (i & -i)
if j <= ins.n:
ins.tree[j] += ins.tree[i]
return ins
def _prefix_sum(self, i: int):
"""Prefix sum of [0, i) (0-indexed)"""
res = 0
while i > 0:
res += self.tree[i]
i -= i & -i
return res
def _add(self, i: int, x) -> None:
"""Add x at i (0-indexed)"""
i += 1
while i <= self.n:
self.tree[i] += x
i += i & -i
def add_range(self, l: int, r: int, x) -> None:
"""Add x at [l, r) (0-indexed)"""
self._add(l, x)
self._add(r, -x)
def __getitem__(self, i: int):
return self._prefix_sum(i + 1)
def __iter__(self):
yield from (self[i] for i in range(self.n))
def __str__(self) -> str:
return f"{self.__class__.__name__}({list(self)})"
一点参照はprefix_sum
を一回実行するだけなので、今回の手法での一点参照の効率化は図れません。
ではどこで使えるかというと、配列要素のIterationで今回の手法が使えます(上の実装例の__iter__
メソッド)。
naiveな実装の__iter__
では、毎回prefix_sum
を呼び出して各要素を求めますが、自分より前の区間について繰り返し加算しており少し無駄な計算をしています。
改善版の__iter__
は最初res = 0
から始めて、隣接するTree上の一点要素を一つ足して次の要素を順番に得ることで実装できます。ここの一点要素取得に今回の手法を用いることで、毎回prefix_sum
をする方法より効率的に必要な要素を加算でき、定数倍改善できます。
def _get_point(self, i: int):
j = i + 1
k = 1
res = self.tree[j]
while j & k == 0:
res -= self.tree[j - k]
k <<= 1
return res
def __iter__(self):
res = 0
for i in range(self.n):
res += self._get_point(i)
yield res
# __iter__の代わりにtolistを実装してもいい。やってることは同じ
def tolist(self):
res = [self.tree[1]]
for i in range(1, self.n):
res.append(res[-1] + self._get_point(i))
return res
_get_point
で今回の手法により一点参照を行い、呼び出しのたびに足していくような実装です。
区間加算・区間和BITへの適用
定数加算をtree[0]
で、線形に増加する加算をtree[1]
という2本のBITで管理することで、区間加算・区間和BITを実装できることが知られています。
この区間加算・区間和BITにおいても一点参照の改良はできるのでしょうか?実はできます。
区間加算・区間和BITの実装例(改良前)
class BinaryIndexedTree:
__slot__ = ("n", "tree")
def __init__(self, n: int) -> None:
self.n = n
self.tree = [[0] * (self.n + 1), [0] * (self.n + 1)]
@classmethod
def from_array(cls, arr) -> "BinaryIndexedTree":
ins = cls(len(arr))
ins.tree[0][1:] = arr
for i in range(1, ins.n):
j = i + (i & -i)
if j <= ins.n:
ins.tree[0][j] += ins.tree[0][i]
return ins
def _sum_partial(self, p: int, i: int):
res = 0
while i > 0:
res += self.tree[p][i]
i -= i & -i
return res
def prefix_sum(self, i: int):
"""Prefix sum of [0, i) (0-indexed)"""
return self._sum_partial(0, i) + self._sum_partial(1, i) * i
def _add_partial(self, p: int, i: int, x) -> None:
while i <= self.n:
self.tree[p][i] += x
i += i & -i
def add_range(self, l: int, r: int, x) -> None:
"""Add x at [l, r) (0-indexed)"""
l += 1
self._add_partial(0, l, -x * (l - 1))
self._add_partial(0, r, x * r)
self._add_partial(1, l, x)
self._add_partial(1, r, -x)
def add_point(self, i: int, x) -> None:
"""Add x at i (0-indexed)"""
self._add_partial(0, i + 1, x)
def update_point(self, i: int, x) -> None:
"""Update i to x (0-indexed)"""
self._add_partial(0, i + 1, x - self[i])
def range_sum(self, l: int, r: int):
"""Range sum of [l, r) (0-indexed)"""
return self.prefix_sum(r) - self.prefix_sum(l)
def __getitem__(self, i: int):
return self.range_sum(i, i + 1)
def __iter__(self):
yield from (self[i] for i in range(self.n))
def __str__(self) -> str:
return f"{self.__class__.__name__}({list(self)})"
詳細は省きますが、区間加算・区間和BITのprefix_sum(i)
は以下のように実装されます。
def prefix_sum(self, i: int):
"""Prefix sum of [0, i) (0-indexed)"""
return self._sum_partial(0, i) + self._sum_partial(1, i) * i
ここで_sum_partial(p, i)
は一点加算BITのprefix_sum(i)
と同じ動作をし、p
で木の種類を指定しています。
range_sum(i, i+1)
を式変形すると、
\begin{eqnarray*}
\text{range_sum}(i,\ i+1) & = & \text{prefix_sum}(i+1)-\text{prefix_sum}(i) \\\\
& = & \text{sum_partial}(0,\ i+1) + \text{sum_partial}(1,\ i+1) * (i+1) \\
& - & \text{sum_partial}(0,\ i) - \text{sum_partial}(1,\ i) * i \\\\
& = & \underline{\text{sum_partial}(0,\ i+1) - \text{sum_partial}(0,\ i)} \\
& + & \big\lbrace\underline{\text{sum_partial}(1,\ i+1) - \text{sum_partial}(1,\ i)}\big\rbrace * i \\
& + & \text{sum_partial}(1,\ i+1)
\end{eqnarray*}
と式変形できます。ここで変形後の下線部は、_sum_partial(p, i)
が一点加算BITのprefix_sum(i)
と同じ動作をすることを考えれば、一点加算BITの一点参照と同じように高速化できます。
よって三段目の$\text{sum_partial}(1,\ i+1)$だけ計算すればよく、元々4回のsum_partial
の計算が3回に改善されることになります。わーい。
def __getitem__(self, i: int):
j = i + 1
k = 1
res0 = self.tree[0][j]
res1 = self.tree[1][j]
while j & k == 0:
res0 -= self.tree[0][j - k]
res1 -= self.tree[1][j - k]
k <<= 1
return res0 + res1 * i + self._sum_partial(1, j)
この実装ではres0
は一項目の$\text{sum_partial}(0,\ i+1) - \text{sum_partial}(0,\ i)$を、res1
は二項目の$\text{sum_partial}(1,\ i+1) - \text{sum_partial}(1,\ i)$を担っています。それぞれを一点加算BITの時と同様に計算することができます。
まとめと今後の課題
以上のようにrange_sum
による一点参照の実装より高速化することができました。これより早くできる実装はないと思います(あったら教えてください。すごく驚きます)。
着想から実装までを順に追っていったので長文になってしまいましたが、実装を理解していただけたならうれしいです。
ただまだ以下の課題は残っているので、解決できた方はぜひ教えてください。
2次元BITへの適用
1次元のBITを2本持つことで2次元上の一点加算・区間和クエリに対応するBITや、1次元のBITを4本持つことで2次元上の区間加算・区間和クエリに対応するBITが知られています。
これらの2次元BITにも今回の手法が適用できそうな気配がありますが、自分の頭では上手く適用することができませんでした。誰か代わりにお願いします。
References
-
Binary Indexed Trees | Read the actual frequency at a position | topcoder
- 若干違う実装について言及。
-
Binary indexed tree | Project Nayuki
- Treeを0-indexedで構築した場合の実装例。1-indexedとどっちが早いかは未確認。
-
車輪の再発明ですが、考察の過程がほかの人の理解の助けになればいいな~というモチベーションで書いてます。 ↩