search
LoginSignup
9
Help us understand the problem. What are the problem?

posted at

updated at

【競プロ専用】PythonでMultiSetを今度こそ

簡単軽量で競プロ攻略に十分な、Python版MultiSetを提供します。

AtCoderでは、最近(2022年)、Pythonには存在しないMultiSetを前提とした問題が、繰り返し出題されています。この暴挙により、Python、PyPy3を使うプレイヤーが、死屍累々たるありさまです。一部の上位プレイヤーは、既存のMultiSetライブラリをググって貼り付けたり、自分で用意したりして、対応しているようです。

しかし、これには、以下の課題があります。

  • 平衡二分木を使ったMultiSetの実装は、複雑長大であり、応用含めた使い方がわかりにくい。
  • heapqやセグメント木(BIT含む)を使った実装は、制限条件が多すぎて、競プロ問題に適用しにくい。

よって、この記事では、BITを使った簡単実装でありながら、競プロ問題への応用性を備えたMultiSetを提供します。具体的には以下の特徴を持ちます。

  • 容易に理解可能な約50行の軽量実装。(別途、BITの実装が必要です)
  • Python使いにも理解しやすい、Python set、list、bisectに似たメソッドサポート。
  • 簡易MultiSetの弱点であった「座標圧縮」を標準サポート。
  • レアな問題で必要になる、複数の同じ数を一度にadd、remove、ouuntする関数をサポート。さらにこれらの高速化のため、BITに加えてリストも内部利用。
  • 原理的に、インタラクティブ問題以外のMultiSet問題に、全て対応可能。

1. メソッドの説明

mset = MultiSet(n=0, compress=[], multi=True)

  • MultiSetを作成します。
  • 各引数は省略可です。nは引数名を指定不要です。ncompressのどちらかは指定してください。
    • 例) mset = MultiSet(10) # 0〜9の値を取りうるマルチセットを作成
    • 例) mset = MultiSet(compress=A) # Aで座圧するマルチセットを作成
    • 例) mset = MultiSet(compress=A, multi=False) # Aで座圧するOrderedSetを作成
  • compressを指定しない場合は、MultiSetの取りうる値は0〜n-1に限定されます。この場合、nは、MultiSetの取りうる値の種類の個数になります。
  • compressは座標圧縮する元となる整数リストです。compressを指定すると、nは無視されて、MultiSetの取りうる値はcompressの値すべてになります。
  • multiTrue(デフォルト)の場合、MultiSet動作(同じ値を複数保持可能)になります。Falseの場合、OrderedSet動作(順序は持つが、Pythonのsetと同様に同じ値は同一視する)になります。

mset.add(x, n=1)

  • MultiSetに値xn個追加します。nは省略すると1となります。OrderedSet動作の場合、n=1以外を指定するとエラーになります。
  • 座圧している場合は、xは座圧要素のどれかである必要があります。それ以外のxを指定すると例外となります。
  • $O(\log n)$ で動作します。

mset.remove(x, n=1)

  • MultiSetから値xn個削除します。nは省略すると1となります。OrderedSet動作の場合、n=1以外を指定するとエラーになります。
  • xはMultiSetに現在n個以上含まれる要素である必要があります。それ以外のxnを指定すると例外となります。
  • $O(\log n)$ で動作します。

print(mset)

  • MultiSetを表示します。
  • $O(n \log n)$ で動作します。

list(mset)

  • MultiSetをソート済リストに変換します。
  • $O(n \log n)$ で動作します。

set(mset)

  • MultiSetを集合型に変換します。同じ要素は同一視されます。
  • $O(n \log n)$ で動作します。

len(mset)

  • MultiSetの要素数を返します。
  • $O(1)$ で動作します。

mset.count(x)

  • MultiSetの要素xの要素数を返します。
  • $O(1)$ で動作します。

mset[idx]

  • MultiSetのidx番目の要素を返します。
  • idxは、MultiSetの要素数をNとした場合、-N以上N - 1未満での整数である必要があります。それ以外のidxを指定すると例外となります。Pythonのリストと同様に、マイナスのidxは、末尾からの逆順となります。
  • $O(\log n)$ で動作します。

x in mset

  • MultiSetにxが含まれているか判定します。
  • xは任意の整数を指定可能です。
  • $O(\log n)$ で動作します。

x not in mset

  • MultiSetにxが含まれていないことを判定します。
  • xは任意の整数を指定可能です。
    • $O(\log n)$ で動作します。

mset.bisect_left(x)

  • MultiSetをソート済リストとみなしてbisect_leftx以上の最小の値のインデックス)を返します。
  • xは任意の整数を指定可能です。
  • $O(\log n)$ で動作します。

mset.bisect_right(x)

  • MultiSetをソート済リストとみなしてbisect_rightxを超える最小の値のインデックス)を返します。
  • xは任意の整数を指定可能です。
  • $O(\log n)$ で動作します。

2. 利用例

A = [0, 1, 2, 3, 4, 100, 10000]
mset = MultiSet(compress=A)
mset.add(2)
mset.add(4)
mset.add(100, 2)
mset.add(10000)
print(mset)
# MultiSet {2, 4, 100, 100, 10000}
print(5 in mset)
# False
print(100 in mset)
# True
print(len(mset))
# 5
print(mset.count(100))
# 2
print(mset.bisect_left(6))
# 2
print(mset.bisect_left(100))
# 2
print(mset.bisect_right(100))
# 4

3. 応用

実際に、直近のMultiSet問題を解いてみます。

3.1. ABC217-D問題

ABC217-D - Cutting Woods

クエリーによって少しずつMultiSetが成長する中で、xを含む区間を求める問題です。
最初に取りうるxの値全てを抽出して座圧します。

PyPy3でのACタイムは約700msです。(実行時間制限: 2秒)

# 入力
L, Q = map(int, input().split())
query = [list(map(int, input().split())) for _ in range(Q)]
# c==1の時の値xを全て抽出して、両端を加えて、OrderedSetを座圧初期化
oset = MultiSet(compress=[0] + [x for c, x in query if c == 1] + [L], multi=False)
# 両端を設定
oset.add(0)
oset.add(L)
# クエリーを1つずつ処理する
for c, x in query:
    if c == 1:
        oset.add(x)
    else:
        i = oset.bisect_left(x)
        print(oset[i] - oset[i - 1])

3.2. ABC241-D問題

ABC241-D - Sequence Query

クエリーによって少しずつMultiSetが成長する中で、x以上(以下)k番目の値を求める問題です。最初に取りうるxの値全てを抽出して座圧します。

PyPy3でのACタイムは約600msです。(実行時間制限: 2秒)

# 入力
Q = int(input())
query = [list(map(int, input().split())) for _ in range(Q)]
# 値xを全て抽出して、MultiSetを座圧初期化
X = [query[i][1] for i in range(Q) if query[i][0] == 1]  
mset = MultiSet(compress=X)
# クエリーを1つずつ処理する
for q in query:
    t, x = q[0], q[1]
    if len(q) == 3: k = q[2]
    if t == 1:
        mset.add(x)
    elif t == 2:
        i = mset.bisect_right(x) - 1  # x以下の最大index
        # k番目が存在すれば出力、しなければ-1
        print(mset[i - k + 1] if i - k + 1 >= 0 else -1)
    else:
        i = mset.bisect_left(x)       # x以上の最小index
        # k番目が存在すれば出力、しなければ-1
        print(mset[i + k - 1] if i + k - 1 < len(mset) else -1)

3.3. ABC245-E問題

ABC245-E - Wrapping Chocolate

応用問題です。縦サイズでチョコと箱をソートしておき、縦サイズでの比較で箱に入る可能性があるチョコをMultiSetに貯めておき、もっともギリギリで入るチョコ(MultiSetの中で箱の横サイズ以下の最大のチョコ)を選択して取り出し、箱に入れていきます。予め、MultiSetに入れて比較するチョコの横サイズBで座圧しておきます。

PyPy3でのACタイムは約1500msです。(実行時間制限: 4秒)

# 入力
N, M = map(int, input().split())
A = list(map(int, input().split()))
B = list(map(int, input().split()))
C = list(map(int, input().split()))
D = list(map(int, input().split()))
# A, C でソート
np_T = lambda x: [list(x) for x in zip(*x)]
A, B = np_T(sorted([(a, b) for a, b in zip(A, B)]))
C, D = np_T(sorted([(c, d) for c, d in zip(C, D)]))
# Bで座圧しておく
mset = MultiSet(compress=B)
n = m = 0
while m < M:
    if n < N and A[n] <= C[m]:   # チョコが箱に入れられる -> Bをmsetに入れておく
        mset.add(B[n])
        n += 1
    else:          # 箱に入る可能性ありは全て見た -> msetから最適なものを取り出す
        i = mset.bisect_right(D[m])  # 箱に入る最大のチョコを探す
        if i > 0:  # 対象が存在したら削除(箱に入れる)
            mset.remove(mset[i - 1])
        m += 1     # 次の箱を見る(対象が存在しないなら、この箱は無視)
# Bが全て箱に入ったらYes
print('Yes' if n == N and len(mset) == 0 else 'No')

3.4. ARC140-B問題

ARC140-B - Shorten ARC

ARCにおいても、Multisetを使う問題が出ています。この問題は、Multisetを使わなくても天才的解法が存在しますが、Multisetを使うと難易度が下がります。

PyPy3でのACタイムは約200msです。(実行時間制限: 2秒)

いわゆる「アルゴリズムで殴る」ことができます。

N = int(input())
S = input()
# 'ARC'のネスト数を数えて、msetに入れる
nest = 0
mset = MultiSet(N + 1)
for i in range(N):
    if S[i - 2: i + 1] == 'ARC':
        nest = 1
    elif nest > 0 and S[i] == 'C' and i - nest * 2 - 2 >= 0 and S[i - nest * 2 - 2] == 'A':
        nest += 1
    elif nest > 0:
        mset.add(nest)
        nest = 0
if nest > 0:
    mset.add(nest)

ans = 0
odd = True
while len(mset) > 0:
    if odd:  # 最大のネスト数を1つ減らす
        x = mset[-1]
        mset.remove(x)
        if x > 1:
            mset.add(x - 1)
    else:    # 最小のネスト数を削除する
        x = mset[0]
        mset.remove(x)
    ans += 1
    odd = not odd
print(ans)

3.5. ARC253-C問題

ARC253-C - Max - Min Query

ついにABC-C問題でmultisetを要求される、暗黒時代になりました。しかも、countn個まとめてremoveする機能を使う、レア問題です。

Python向けmultiset実装の多くが、この問題には敗れ去ったものと思います。筆者も、本記事の実装(以前の版)を使いTLEしました。しかしながら、軽量実装の利点を活かして、本問題取り組み中にremoveの複数要素対応を追加実装して、ACにこぎつけました。

PyPy3でのACタイムは約700msです。(実行時間制限: 2秒)

Q = int(input())
query = [list(map(int, input().split())) for _ in range(Q)]

# 座圧候補をリストアップ
xlist = []
for t in query:
    if t[0] == 1 or t[0] == 2:
        xlist.append(t[1])
# 座圧してmultisetを作成 
mset = MultiSet(compress=xlist)
# クエリーを素直に解く
for t in query:
    if t[0] == 1:
        x = t[1]
        mset.add(x)
    elif t[0] == 2:
        x, c = t[1:]
        count_x = mset.count(x)
        mset.remove(x, min(c, count_x))
    else:
        print(mset[-1] - mset[0])

4. 実装

MultiSetの実装は以下です。

2本のセグメント木を使って実装する例をよく見かけますが、BIT1本のみで実装できたため、Python・PyPy3としては、比較的高速軽量に動作します。

なお、別途、二分探索をサポートしているBITの実装が必要です。こちらを参照してclass BITを貼り付けてください。

import bisect
class MultiSet:
    # n: サイズ、compress: 座圧対象list-likeを指定(nは無効)
    # multi: マルチセットか通常のOrderedSetか
    def __init__(self, n=0, *, compress=[], multi=True):
        self.multi = multi
        self.inv_compress = sorted(set(compress)) if len(compress) > 0 else [i for i in range(n)]
        self.compress = {k: v for v, k in enumerate(self.inv_compress)}
        self.counter_all = 0
        self.counter = [0] * len(self.inv_compress)
        self.bit = BIT(len(self.inv_compress))

    def add(self, x, n=1):     # O(log n)
        if not self.multi and n != 1: raise KeyError(n)
        x = self.compress[x]
        count = self.counter[x]
        if count == 0 or self.multi:  # multiなら複数カウントできる
            self.bit.add(x + 1, n)
            self.counter_all += n
            self.counter[x] += n

    def remove(self, x, n=1):  # O(log n)
        if not self.multi and n != 1: raise KeyError(n)
        x = self.compress[x]
        count = self.bit.get(x + 1)
        if count < n: raise KeyError(x)
        self.bit.add(x + 1, -n)
        self.counter_all -= n
        self.counter[x] -= n

    def __repr__(self):
        return f'MultiSet {{{(", ".join(map(str, list(self))))}}}'

    def __len__(self):         # oprator len: O(1)
        return self.counter_all

    def count(self, x):        # O(1)
        return self.counter[self.compress[x]]

    def __getitem__(self, i):  # operator []: O(log n)
        if i < 0: i += len(self)
        x = self.bit.lower_bound(i + 1)
        if x > self.bit.n: raise IndexError('list index out of range')
        return self.inv_compress[x - 1]

    def __contains__(self, x): # operator in: O(log n)
        return self.bit.get(self.compress.get(x, self.bit.n) + 1, 0) > 0

    def bisect_left(self, x):  # O(log n)
        return self.bit.sum(bisect.bisect_left(self.inv_compress, x))

    def bisect_right(self, x): # O(log n)
        return self.bit.sum(bisect.bisect_right(self.inv_compress, x))

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
9
Help us understand the problem. What are the problem?