#はじめに
pythonで実装した2Dセグメント木について記事が見当たらなかったので紹介しようと思います。
#参考文献
https://qiita.com/tomato1997/items/da9a7a73f2301aa48896
https://www.hamayanhamayan.com/entry/2017/12/09/015937
http://algoogle.hadrori.jp/algorithm/2d-segment-tree.html
https://ei1333.github.io/luzhiled/snippets/structure/segment-tree.html
#2Dセグメント木とは
ある矩形領域に対して、和や最小最大などなどの値をlogオーダーで取り出し可能とするものです。ある一点の値の変更もlogオーダーで可能です。
例として4×4の配列を扱います。セグメント木に乗せる関数はminとしました。(以下、0-indexedで説明します)
###木の構築
言葉での説明は難しいので図を見てください。右下の黒枠が元の配列です。親子に対応する部分の一部を色で塗ってみました。左または上の方が親要素になっています。
木の構築はO(NM)です。
###値の更新
例えば(1,2)の値を0にする場合、以下の図のような部分が更新されます。
値の更新はO(logNlogM)です。
###区間の値の取得
[0,3)×[0,3)を取得する場合、見なければいけない領域は図の色を付けた部分です。
この場合は[0,2)×[0,2)、[0,2)×[2,3)、[2,3)×[0,2)、[2,3)×[2,3)の4か所だということが分かります。
値の取得はO(logNlogM)です。
#コード
上の考察を用いてpythonにて実装しました。
import math
class SegTree2D():
DEFAULT = {
'min': 1 << 60,
'max': -(1 << 60),
'sum': 0,
'prd': 1,
'gcd': 0,
'lmc': 1,
'^': 0,
'&': (1 << 60) - 1,
'|': 0,
}
FUNC = {
'min': min,
'max': max,
'sum': (lambda x, y: x + y),
'prd': (lambda x, y: x * y),
'gcd': math.gcd,
'lmc': (lambda x, y: (x * y) // math.gcd(x, y)),
'^': (lambda x, y: x ^ y),
'&': (lambda x, y: x & y),
'|': (lambda x, y: x | y),
}
def __init__(self,ls2D, mode='min', func=None, default=None):
"""
要素ls2D, 関数mode (min,max,sum,prd(product),gcd,lmc,^,&,|)
func,defaultを指定すれば任意の関数、単位元での計算が可能
"""
N = len(ls2D)
M = len(ls2D[0])
if default == None:
self.default = self.DEFAULT[mode]
else:
self.default = default
if func == None:
self.func = self.FUNC[mode]
else:
self.func = func
self.N = N
self.M = M
self.KN = (N - 1).bit_length()
self.KM = (M - 1).bit_length()
self.N2 = 1 << self.KN
self.M2 = 1 << self.KM
self.dat = [[self.default] * (2**(self.KM + 1)) for i in range(2**(self.KN + 1))]
for i in range(self.N):
for j in range(self.M):
self.dat[self.N2 + i][self.M2 + j] = ls2D[i][j]
self.build()
def build(self):
for j in range(self.M):
for i in range(self.N2 - 1, 0, -1):
self.dat[i][self.M2 + j] = self.func(self.dat[i << 1][self.M2 + j], self.dat[i << 1 | 1][self.M2 + j])
for i in range(2**(self.KN + 1)):
for j in range(self.M2 - 1, 0, -1):
self.dat[i][j] = self.func(self.dat[i][j << 1], self.dat[i][j << 1 | 1])
def leafvalue(self, x,y): # (x,y)番目の値の取得
return self.dat[x + self.N2][y + self.M2]
def update(self, x, y, value): # (x,y)の値をvalueに変える
i = x + self.N2
j = y + self.M2
self.dat[i][j] = value
while j > 1:
j >>= 1
self.dat[i][j] = self.func(self.dat[i][j << 1], self.dat[i][j << 1 | 1])
j = y + self.M2
while i > 1:
i >>= 1
self.dat[i][j] = self.func(self.dat[i << 1][j], self.dat[i << 1 | 1][j])
while j > 1:
j >>= 1
self.dat[i][j] = self.func(self.dat[i][j << 1], self.dat[i][j << 1 | 1])
j = y + self.M2
return
def query(self, Lx, Rx, Ly, Ry): # [Lx,Rx)×[Ly,Ry)の区間取得
Lx += self.N2
Rx += self.N2
Ly += self.M2
Ry += self.M2
vLx = self.default
vRx = self.default
while Lx < Rx:
if Lx & 1:
vLy = self.default
vRy = self.default
Ly1 = Ly
Ry1 = Ry
while Ly1 < Ry1:
if Ly1 & 1:
vLy = self.func(vLy, self.dat[Lx][Ly1])
Ly1 += 1
if Ry1 & 1:
Ry1 -= 1
vRy = self.func(self.dat[Lx][Ry1], vRy)
Ly1 >>= 1
Ry1 >>= 1
vy = self.func(vLy, vRy)
vLx = self.func(vLx,vy)
Lx += 1
if Rx & 1:
Rx -= 1
vLy = self.default
vRy = self.default
Ly1 = Ly
Ry1 = Ry
while Ly1 < Ry1:
if Ly1 & 1:
vLy = self.func(vLy, self.dat[Rx][Ly1])
Ly1 += 1
if Ry1 & 1:
Ry1 -= 1
vRy = self.func(self.dat[Rx][Ry1], vRy)
Ly1 >>= 1
Ry1 >>= 1
vy = self.func(vLy, vRy)
vRx = self.func(vy, vRx)
Lx >>= 1
Rx >>= 1
return self.func(vLx, vRx)
###使用例
AOJ School of Killifish 値の更新はない問題です。
http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=1068
http://judge.u-aizu.ac.jp/onlinejudge/review.jsp?rid=5153232#1
Atcoder ABC106D AtCoder Express 2 オーバーキル感ありますが当然普通の二次元区間和も扱えます
https://atcoder.jp/contests/abc106/tasks/abc106_d
https://atcoder.jp/contests/abc106/submissions/19563643
#おわりに
ここまでお読みいただきありがとうございました。私自身競技プログラミング初学者ですので、間違い等ございましたらご指摘よろしくお願いいたします。また、LGTMしていただけると励みになります。