5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

pythonで2Dセグメント木実装

Last updated at Posted at 2021-01-22

#はじめに
pythonで実装した2Dセグメント木について記事が見当たらなかったので紹介しようと思います。

#参考文献
https://qiita.com/tomato1997/items/da9a7a73f2301aa48896
https://www.hamayanhamayan.com/entry/2017/12/09/015937
http://algoogle.hadrori.jp/algorithm/2d-segment-tree.html
https://ei1333.github.io/luzhiled/snippets/structure/segment-tree.html

#2Dセグメント木とは
ある矩形領域に対して、和や最小最大などなどの値をlogオーダーで取り出し可能とするものです。ある一点の値の変更もlogオーダーで可能です。

例として4×4の配列を扱います。セグメント木に乗せる関数はminとしました。(以下、0-indexedで説明します)
###木の構築
言葉での説明は難しいので図を見てください。右下の黒枠が元の配列です。親子に対応する部分の一部を色で塗ってみました。左または上の方が親要素になっています。
図1.jpg
木の構築はO(NM)です。
###値の更新
例えば(1,2)の値を0にする場合、以下の図のような部分が更新されます。
図2.jpg
値の更新はO(logNlogM)です。
###区間の値の取得
[0,3)×[0,3)を取得する場合、見なければいけない領域は図の色を付けた部分です。
図3.jpg
この場合は[0,2)×[0,2)、[0,2)×[2,3)、[2,3)×[0,2)、[2,3)×[2,3)の4か所だということが分かります。
値の取得はO(logNlogM)です。

#コード
上の考察を用いてpythonにて実装しました。

segki2D.py
import math
class SegTree2D():
    DEFAULT = {
        'min': 1 << 60,
        'max': -(1 << 60),
        'sum': 0,
        'prd': 1,
        'gcd': 0,
        'lmc': 1,
        '^': 0,
        '&': (1 << 60) - 1,
        '|': 0,
    }

    FUNC = {
        'min': min,
        'max': max,
        'sum': (lambda x, y: x + y),
        'prd': (lambda x, y: x * y),
        'gcd': math.gcd,
        'lmc': (lambda x, y: (x * y) // math.gcd(x, y)),
        '^': (lambda x, y: x ^ y),
        '&': (lambda x, y: x & y),
        '|': (lambda x, y: x | y),
    }

    def __init__(self,ls2D, mode='min', func=None, default=None):
        """
        要素ls2D, 関数mode (min,max,sum,prd(product),gcd,lmc,^,&,|)
        func,defaultを指定すれば任意の関数、単位元での計算が可能
        """
        N = len(ls2D)
        M = len(ls2D[0])
        if default == None:
            self.default = self.DEFAULT[mode]
        else:
            self.default = default
        if func == None:
            self.func = self.FUNC[mode]
        else:
            self.func = func
        self.N = N
        self.M = M
        self.KN = (N - 1).bit_length()
        self.KM = (M - 1).bit_length()
        self.N2 = 1 << self.KN
        self.M2 = 1 << self.KM
        self.dat = [[self.default] * (2**(self.KM + 1)) for i in range(2**(self.KN + 1))]
        for i in range(self.N):
            for j in range(self.M):
                self.dat[self.N2 + i][self.M2 + j] = ls2D[i][j]
        self.build()

    def build(self):
        for j in range(self.M):
            for i in range(self.N2 - 1, 0, -1):
                self.dat[i][self.M2 + j] = self.func(self.dat[i << 1][self.M2 + j], self.dat[i << 1 | 1][self.M2 + j])
        for i in range(2**(self.KN + 1)):
            for j in range(self.M2 - 1, 0, -1):
                self.dat[i][j] = self.func(self.dat[i][j << 1], self.dat[i][j << 1 | 1])

    def leafvalue(self, x,y):  # (x,y)番目の値の取得
        return self.dat[x + self.N2][y + self.M2]

    def update(self, x, y, value):  # (x,y)の値をvalueに変える
        i = x + self.N2
        j = y + self.M2
        self.dat[i][j] = value
        while j > 1:
            j >>= 1
            self.dat[i][j] = self.func(self.dat[i][j << 1], self.dat[i][j << 1 | 1])
        j = y + self.M2
        while i > 1:
            i >>= 1
            self.dat[i][j] = self.func(self.dat[i << 1][j], self.dat[i << 1 | 1][j])
            while j > 1:
                j >>= 1
                self.dat[i][j] = self.func(self.dat[i][j << 1], self.dat[i][j << 1 | 1])
            j = y + self.M2
        return

    def query(self, Lx, Rx, Ly, Ry):  # [Lx,Rx)×[Ly,Ry)の区間取得
        Lx += self.N2
        Rx += self.N2
        Ly += self.M2
        Ry += self.M2
        vLx = self.default
        vRx = self.default       
        while Lx < Rx:
            if Lx & 1:
                vLy = self.default
                vRy = self.default
                Ly1 = Ly
                Ry1 = Ry
                while Ly1 < Ry1:
                    if Ly1 & 1:
                        vLy = self.func(vLy, self.dat[Lx][Ly1])
                        Ly1 += 1
                    if Ry1 & 1:
                        Ry1 -= 1
                        vRy = self.func(self.dat[Lx][Ry1], vRy)
                    Ly1 >>= 1
                    Ry1 >>= 1
                vy = self.func(vLy, vRy)
                vLx = self.func(vLx,vy)
                Lx += 1
            if Rx & 1:
                Rx -= 1
                vLy = self.default
                vRy = self.default
                Ly1 = Ly
                Ry1 = Ry
                while Ly1 < Ry1:
                    if Ly1 & 1:
                        vLy = self.func(vLy, self.dat[Rx][Ly1])
                        Ly1 += 1
                    if Ry1 & 1:
                        Ry1 -= 1
                        vRy = self.func(self.dat[Rx][Ry1], vRy)
                    Ly1 >>= 1
                    Ry1 >>= 1 
                vy = self.func(vLy, vRy)               
                vRx = self.func(vy, vRx)
            Lx >>= 1
            Rx >>= 1
        return self.func(vLx, vRx)

###使用例
AOJ School of Killifish 値の更新はない問題です。
http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=1068
http://judge.u-aizu.ac.jp/onlinejudge/review.jsp?rid=5153232#1
Atcoder ABC106D AtCoder Express 2 オーバーキル感ありますが当然普通の二次元区間和も扱えます
https://atcoder.jp/contests/abc106/tasks/abc106_d
https://atcoder.jp/contests/abc106/submissions/19563643

#おわりに
ここまでお読みいただきありがとうございました。私自身競技プログラミング初学者ですので、間違い等ございましたらご指摘よろしくお願いいたします。また、LGTMしていただけると励みになります。

5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?