1
1

More than 1 year has passed since last update.

PythonでDisjoint Sparse Tableを書いたよ。

Last updated at Posted at 2023-07-23

PythonでDisjoint Sparse Tableをかきました。
好きにつかってやってください!

コード

長いので折りたたみにしてます。

コード詳細
class DisjointSparseTable:
    """Disjoint Sparse Table

    リストに対して、事前構築O(NlogN)、区間演算O(1)でできます。
    演算funcは、結合則を満たすものだけです。
        結合則: func(a, b) == func(b, a)  (順番は関係なし)

    Attributes:
        lst (list): 対象のリスト
        func (Callable[[int,int],int]): 演算 (xor, gcd, min, max, ...)

    Examples:
        >>> dst = DisjointSparseTable([9, 1, 6, 10, 2], min)
        >>> dst.prod(0, 4)
        1
        >>> dst.prod(2, 4)
        2

        >>> def xor(i, j): return i ^ j
        >>> dst = DisjointSparseTable([9, 1, 6, 10, 2], xor)
        >>> dst.prod(0, 4)  # 9^1^6^10^2=6
        6
        >>> dst.prod(0, 2)  # 9^1^6=14
        14
    """

    def __init__(self, lst: list, func: "function") -> None:
        length = 1 << (len(lst) - 1).bit_length()
        self.my_lst = lst + [0] * (length - len(lst))
        self.__func = func

        # テーブルの作成
        self.tables = []
        dist = length >> 1
        gc = 2  # グループの数
        while dist:
            now_table = [None] * length
            left_start = 0
            left, right = 0, dist - 1
            for _ in range(gc):
                if left_start:
                    val = self.my_lst[left]
                    for i in range(left, right + 1):
                        if i != left:
                            val = func(val, self.my_lst[i])
                        now_table[i] = val
                else:
                    # right_start
                    val = self.my_lst[right]
                    for i in range(right, left - 1, -1):
                        if i != right:
                            val = func(val, self.my_lst[i])
                        now_table[i] = val
                left += dist
                right += dist
                left_start ^= 1
            self.tables.append(now_table)
            dist >>= 1
            gc <<= 1

        # self.bit_lengthリストの作成
        self.bit_length = [0]
        for i in range(1, (len(lst) - 1).bit_length() + 1):
            for s in range(1 << (i - 1)):
                self.bit_length.append(i)

    def prod(self, left: int, right: int) -> int:
        """区間演算

        閉区間 [left, right] で演算 func を行った結果を返します。

        Args:
            left (int): 区間の左端のインデックス
            right (int): 区間の右端のインデックス

        Returns:
            int: 閉区間 [left, right] で演算 func を行った結果

        Note:
            閉区間 [left, right] なのに注意してね。

        """
        assert 0 <= left <= right < len(self.my_lst), \
            "0<=left<=right<len(self.my_lst), left={0}, right={1}" \
                .format(left, right)
        if left == right:
            return self.my_lst[left]

        # table_idx = len(self.tables) - (left ^ right).bit_length()
        table_idx = len(self.tables) - self.bit_length[left ^ right]
        left_el = self.tables[table_idx][left]
        right_el = self.tables[table_idx][right]
        return self.__func(left_el, right_el)

つかいかた

docstringに書いてある通りですが、別の例も書いておきます。

# 上のコードをコピペしてから書いてね。

import math

dst = DisjointSparseTable([10, 30, 7, 21, 6], lambda x, y: math.gcd(x, y))
print(dst.prod(0, 1))  # 10
print(dst.prod(2, 3))  # 7
print(dst.prod(0, 4))  # 1


def mul(x, y): return x * y
dst2 = DisjointSparseTable([10, 30, 7, 21, 6], mul)
print(dst2.prod(0, 2))  # 2100
print(dst2.prod(3, 4))  # 126
print(dst2.prod(0, 4))  # 264600

ACコード

ABC189-C - Mandarin Orange

参考URL

ほぼここに書いてあるのをコードにしただけです!!ありがとうございます!!!
noshi91のメモ - Disjoint Sparse Table と セグ木に関するポエム

編集履歴

2023/07/23 bit_lengthの前計算で、クエリをO(1)で処理できるように改善しました。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1