2023年AtCoder言語アップデートにより、MultiSet機能を含むsortedcontainersライブラリが使えるようになりました。本記事は、2023年AtCoder言語アップデートの前の情報です。
簡単軽量で競プロ攻略に十分な、Python版MultiSetを提供します。
- 2023/1/31:
count()
に座圧以外の数値を指定するとエラーになることの修正をしました。 - 2023/1/31:
in
/not in
を$O(1)$にしました。
AtCoderでは、最近(2022年)、Pythonには存在しないMultiSetを前提とした問題が、繰り返し出題されています。この暴挙により、Python、PyPy3を使うプレイヤーが、死屍累々たるありさまです。一部の上位プレイヤーは、既存のMultiSetライブラリをググって貼り付けたり、自分で用意したりして、対応しているようです。
しかし、これには、以下の課題があります。
- 平衡二分木を使ったMultiSetの実装は、複雑長大であり、応用含めた使い方がわかりにくい。
- heapqやセグメント木(BIT含む)を使った実装は、制限条件が多すぎて、競プロ問題に適用しにくい。
よって、この記事では、BITを使った簡単実装でありながら、競プロ問題への応用性を備えたMultiSetを提供します。具体的には以下の特徴を持ちます。
- 容易に理解可能な約50行の軽量実装。(別途、BITの実装が必要です)
- Python使いにも理解しやすい、Python set、list、bisectに似たメソッドサポート。
- 簡易MultiSetの弱点であった「座標圧縮」を標準サポート。
- レアな問題で必要になる、複数の同じ数を一度にadd、remove、countする関数をサポート。さらにこれらの高速化のため、BITに加えてリストも内部利用。
- 原理的に、インタラクティブ問題以外のMultiSet問題に、全て対応可能。
1. メソッドの説明
mset = MultiSet(n=0, compress=[], multi=True)
- MultiSetを作成します。
- 各引数は省略可です。
n
は引数名を指定不要です。n
かcompress
のどちらかは指定してください。- 例)
mset = MultiSet(10) # 0〜9の値を取りうるマルチセットを作成
- 例)
mset = MultiSet(compress=A) # Aで座圧するマルチセットを作成
- 例)
mset = MultiSet(compress=A, multi=False) # Aで座圧するOrderedSetを作成
- 例)
-
compress
を指定しない場合は、MultiSetの取りうる値は0〜n-1
に限定されます。この場合、n
は、MultiSetの取りうる値の種類の個数になります。 -
compress
は座標圧縮する元となる整数リストです。compress
を指定すると、n
は無視されて、MultiSetの取りうる値はcompress
の値すべてになります。 -
multi
がTrue
(デフォルト)の場合、MultiSet動作(同じ値を複数保持可能)になります。False
の場合、OrderedSet動作(順序は持つが、Pythonのsetと同様に同じ値は同一視する)になります。
mset.add(x, n=1)
- MultiSetに値
x
をn
個追加します。n
は省略すると1
となります。OrderedSet動作の場合、n=1
以外を指定するとエラーになります。 - 座圧している場合は、
x
は座圧要素のどれかである必要があります。それ以外のx
を指定すると例外となります。 - $O(\log n)$ で動作します。
mset.remove(x, n=1)
- MultiSetから値
x
をn
個削除します。n
は省略すると1
となります。OrderedSet動作の場合、n=1
以外を指定するとエラーになります。 -
x
はMultiSetに現在n
個以上含まれる要素である必要があります。それ以外のx
とn
を指定すると例外となります。 - $O(\log n)$ で動作します。
print(mset)
- MultiSetを表示します。
- $O(n \log n)$ で動作します。
list(mset)
- MultiSetをソート済リストに変換します。
- $O(n \log n)$ で動作します。
set(mset)
- MultiSetを集合型に変換します。同じ要素は同一視されます。
- $O(n \log n)$ で動作します。
len(mset)
- MultiSetの要素数を返します。
- $O(1)$ で動作します。
mset.count(x)
- MultiSetの要素
x
の要素数を返します。 -
x
は任意の整数を指定可能です。 - $O(1)$ で動作します。
mset[idx]
- MultiSetの
idx
番目の要素を返します。 -
idx
は、MultiSetの要素数をN
とした場合、-N
以上N - 1
未満での整数である必要があります。それ以外のidx
を指定すると例外となります。Pythonのリストと同様に、マイナスのidx
は、末尾からの逆順となります。 - $O(\log n)$ で動作します。
x in mset
- MultiSetにxが含まれているか判定します。
-
x
は任意の整数を指定可能です。 - $O(1)$ で動作します。
x not in mset
- MultiSetにxが含まれていないことを判定します。
-
x
は任意の整数を指定可能です。 -
- $O(1)$ で動作します。
mset.bisect_left(x)
- MultiSetをソート済リストとみなして
bisect_left
(x
以上の最小の値のインデックス)を返します。 -
x
は任意の整数を指定可能です。 - $O(\log n)$ で動作します。
mset.bisect_right(x)
- MultiSetをソート済リストとみなして
bisect_right
(x
を超える最小の値のインデックス)を返します。 -
x
は任意の整数を指定可能です。 - $O(\log n)$ で動作します。
2. 利用例
A = [0, 1, 2, 3, 4, 100, 10000]
mset = MultiSet(compress=A)
mset.add(2)
mset.add(4)
mset.add(100, 2)
mset.add(10000)
print(mset)
# MultiSet {2, 4, 100, 100, 10000}
print(5 in mset)
# False
print(100 in mset)
# True
print(len(mset))
# 5
print(mset.count(100))
# 2
print(mset.bisect_left(6))
# 2
print(mset.bisect_left(100))
# 2
print(mset.bisect_right(100))
# 4
3. 応用
実際に、直近のMultiSet問題を解いてみます。
3.1. ABC217-D問題
クエリーによって少しずつMultiSetが成長する中で、xを含む区間を求める問題です。
最初に取りうるxの値全てを抽出して座圧します。
PyPy3でのACタイムは約700msです。(実行時間制限: 2秒)
L, Q = map(int, input().split())
query = [list(map(int, input().split())) for _ in range(Q)]
# c==1の時の値xを全て抽出して、両端を加えて、OrderedSetを座圧初期化
oset = MultiSet(compress=[0] + [x for c, x in query if c == 1] + [L], multi=False)
# 両端を設定
oset.add(0)
oset.add(L)
# クエリーを1つずつ処理する
for c, x in query:
if c == 1:
oset.add(x)
else:
i = oset.bisect_left(x)
print(oset[i] - oset[i - 1])
3.2. ABC241-D問題
クエリーによって少しずつMultiSetが成長する中で、x以上(以下)k番目の値を求める問題です。最初に取りうるxの値全てを抽出して座圧します。
PyPy3でのACタイムは約600msです。(実行時間制限: 2秒)
Q = int(input())
query = [list(map(int, input().split())) for _ in range(Q)]
# 値xを全て抽出して、MultiSetを座圧初期化
X = [query[i][1] for i in range(Q) if query[i][0] == 1]
mset = MultiSet(compress=X)
# クエリーを1つずつ処理する
for t, x, *params in query:
if params:
k = params[0]
if t == 1:
mset.add(x)
elif t == 2:
i = mset.bisect_right(x) - 1 # x以下の最大index
# k番目が存在すれば出力、しなければ-1
print(mset[i - k + 1] if i - k + 1 >= 0 else -1)
else:
i = mset.bisect_left(x) # x以上の最小index
# k番目が存在すれば出力、しなければ-1
print(mset[i + k - 1] if i + k - 1 < len(mset) else -1)
3.3. ABC245-E問題
応用問題です。縦サイズでチョコと箱をソートしておき、縦サイズでの比較で箱に入る可能性があるチョコをMultiSetに貯めておき、もっともギリギリで入るチョコ(MultiSetの中で箱の横サイズ以下の最大のチョコ)を選択して取り出し、箱に入れていきます。予め、MultiSetに入れて比較するチョコの横サイズBで座圧しておきます。
PyPy3でのACタイムは約1500msです。(実行時間制限: 4秒)
N, M = map(int, input().split())
A = list(map(int, input().split()))
B = list(map(int, input().split()))
C = list(map(int, input().split()))
D = list(map(int, input().split()))
# A, C でソート
np_T = lambda x: [list(x) for x in zip(*x)]
A, B = np_T(sorted([(a, b) for a, b in zip(A, B)]))
C, D = np_T(sorted([(c, d) for c, d in zip(C, D)]))
# Bで座圧しておく
mset = MultiSet(compress=B)
n = m = 0
while m < M:
if n < N and A[n] <= C[m]: # チョコが箱に入れられる -> Bをmsetに入れておく
mset.add(B[n])
n += 1
else: # 箱に入る可能性ありは全て見た -> msetから最適なものを取り出す
i = mset.bisect_right(D[m]) # 箱に入る最大のチョコを探す
if i > 0: # 対象が存在したら削除(箱に入れる)
mset.remove(mset[i - 1])
m += 1 # 次の箱を見る(対象が存在しないなら、この箱は無視)
# Bが全て箱に入ったらYes
print('Yes' if n == N and len(mset) == 0 else 'No')
3.4. ARC140-B問題
ARCにおいても、Multisetを使う問題が出ています。この問題は、Multisetを使わなくても天才的解法が存在しますが、Multisetを使うと難易度が下がります。
PyPy3でのACタイムは約200msです。(実行時間制限: 2秒)
いわゆる「アルゴリズムで殴る」ことができます。
N = int(input())
S = input()
# 'ARC'のネスト数を数えて、msetに入れる
nest = 0
mset = MultiSet(N + 1)
for i in range(N):
if S[i - 2: i + 1] == 'ARC':
nest = 1
elif nest > 0 and S[i] == 'C' and i - nest * 2 - 2 >= 0 and S[i - nest * 2 - 2] == 'A':
nest += 1
elif nest > 0:
mset.add(nest)
nest = 0
if nest > 0:
mset.add(nest)
ans = 0
odd = True
while len(mset) > 0:
if odd: # 最大のネスト数を1つ減らす
x = mset[-1]
mset.remove(x)
if x > 1:
mset.add(x - 1)
else: # 最小のネスト数を削除する
x = mset[0]
mset.remove(x)
ans += 1
odd = not odd
print(ans)
3.5. ABC253-C問題
ついにABC-C問題でmultisetを要求される、暗黒時代になりました。しかも、count
やn
個まとめてremove
する機能を使う、レア問題です。
Python向けmultiset実装の多くが、この問題には敗れ去ったものと思います。筆者も、本記事の実装(以前の版)を使いTLEしました。しかしながら、軽量実装の利点を活かして、本問題取り組み中にremoveの複数要素対応を追加実装して、ACにこぎつけました。
PyPy3でのACタイムは約700msです。(実行時間制限: 2秒)
Q = int(input())
query = [list(map(int, input().split())) for _ in range(Q)]
# 座圧候補をリストアップ
xlist = []
for t in query:
if t[0] == 1 or t[0] == 2:
xlist.append(t[1])
# 座圧してmultisetを作成
mset = MultiSet(compress=xlist)
# クエリーを素直に解く
for t, *params in query:
if t == 1:
x, = params
mset.add(x)
elif t == 2:
x, c = params
count_x = mset.count(x)
mset.remove(x, min(c, count_x))
else:
print(mset[-1] - mset[0])
4. 実装
MultiSetの実装は以下です。
2本のセグメント木を使って実装する例をよく見かけますが、BIT1本のみで実装できたため、Python・PyPy3としては、比較的高速軽量に動作します。
なお、別途、二分探索をサポートしているBITの実装が必要です。こちらを参照してclass BIT
を貼り付けてください。
import bisect
class MultiSet:
# n: サイズ、compress: 座圧対象list-likeを指定(nは無効)
# multi: マルチセットか通常のOrderedSetか
def __init__(self, n=0, *, compress=[], multi=True):
self.multi = multi
self.inv_compress = sorted(set(compress)) if len(compress) > 0 else [i for i in range(n)]
self.compress = {k: v for v, k in enumerate(self.inv_compress)}
self.counter_all = 0
self.counter = [0] * len(self.inv_compress)
self.bit = BIT(len(self.inv_compress))
def add(self, x, n=1): # O(log n)
if not self.multi and n != 1: raise KeyError(n)
x = self.compress[x]
count = self.counter[x]
if count == 0 or self.multi: # multiなら複数カウントできる
self.bit.add(x + 1, n)
self.counter_all += n
self.counter[x] += n
def remove(self, x, n=1): # O(log n)
if not self.multi and n != 1: raise KeyError(n)
x = self.compress[x]
count = self.bit.get(x + 1)
if count < n: raise KeyError(x)
self.bit.add(x + 1, -n)
self.counter_all -= n
self.counter[x] -= n
def __repr__(self):
return f'MultiSet {{{(", ".join(map(str, list(self))))}}}'
def __len__(self): # oprator len: O(1)
return self.counter_all
def count(self, x): # O(1)
return self.counter[self.compress[x]] if x in self.compress else 0
def __getitem__(self, i): # operator []: O(log n)
if i < 0: i += len(self)
x = self.bit.lower_bound(i + 1)
if x > self.bit.n: raise IndexError('list index out of range')
return self.inv_compress[x - 1]
def __contains__(self, x): # operator in: O(1)
return self.count(x) > 0
def bisect_left(self, x): # O(log n)
return self.bit.sum(bisect.bisect_left(self.inv_compress, x))
def bisect_right(self, x): # O(log n)
return self.bit.sum(bisect.bisect_right(self.inv_compress, x))