はじめに
今回はPythonの優先度付きキューについてまとめます。
AtCoderのPython3.4.3と3.8で動作確認済みです。
優先度付きキューについて
優先度付きキュー (Priority queue) はデータ型の一つで、具体的には
- 最小値(最大値)を $O(\log{N})$で取り出す
- 要素を $O(\log{N})$ で挿入する
ことが出来ます。通常のリストだとそれぞれ $O(N)$ ですので高速です。
「リストの要素の挿入」と「最小値(最大値)を取り出す」ことを繰り返すような時に使います。
Pythonでの使い方
Pythonでは優先度付きキューは heapq として標準ライブラリに用意されています。使いたいときはimportしましょう。
各メソッドについて
頻繁に使うメソッドは3つです。
-
heapq.heapify(リスト)
でリストを優先度付きキューに変換。 -
heapq.heappop(優先度付きキュー (=リスト) )
で優先度付きキューから最小値を取り出す。 -
heapq.heappush(優先度付きキュー (=リスト) , 挿入したい要素)
で優先度付きキューに要素を挿入。
※ Pythonではheap化されたリストのクラスもリストであるためこのような書き方をしています。↓参考
import heapq
a = [1, 6, 8, 0, -1]
print(type(a)) # <class 'list'>
heapq.heapify(a)
print(type(a)) # <class 'list'>
では、各メソッドの使い方について見ていきましょう。
import heapq # heapqライブラリのimport
a = [1, 6, 8, 0, -1]
heapq.heapify(a) # リストを優先度付きキューへ
print(a)
# 出力: [-1, 0, 8, 1, 6] (優先度付きキューとなった a)
print(heapq.heappop(a)) # 最小値の取り出し
# 出力: -1 (a の最小値)
print(a)
# 出力: [0, 1, 8, 6] (最小値を取り出した後の a)
heapq.heappush(a, -2) # 要素の挿入
print(a)
# 出力: [-2, 0, 1, 8, 6] (-2 を挿入後の a)
なお、もちろんリストのheap化は行わなくてもheappop()
とheappush()
は使用できます。しかしheappop()
ではリストの先頭の要素が取り出される仕様になっているため、初回のheappop()
を行う際は気をつけましょう。最初のリストが空のリスト[]
であるような場合にはheapify
を行う必要がないので以下の通りで問題ありません。
import heapq # heapqライブラリのimport
a = [1, 6, 8, 0, -1]
# heapq.heapify(a)
print(heapq.heappop(a)) # 先頭の要素 (1) が取り出される!!!!
print(a) # [-1, 6, 8, 0]
heapq.heappush(a, -2) # 要素の挿入
print(a) # [-2, -1, 8, 0, 6]
print(heapq.heappop(a)) # 最小値の取り出し
print(a) # [-1, 0, 8, 6]
最大値の取り出し
heapqでは最小値しか取り出すことが出来ません。では最大値の時はどうするかというと、各要素に-1をかけた上で最小値を取り出していきます。以下のコードではmap
関数で各要素を-1倍していますが、実際に問題を解く際には入力時に-1した方が高速です。
import heapq
a = [1, 6, 8, 0, -1]
a = list(map(lambda x: x*(-1), a)) # 各要素を-1倍
print(a)
heapq.heapify(a)
print(heapq.heappop(a)*(-1)) # 最大値の取り出し
print(a)
[-8, -6, -1, 0, 1]
8
[-6, 0, -1, 1]
例題 (ABC141 D)
AtCoder Beginners Contest 141 D問題 Powerful Discount Tickets
この問題は 最大値を1/2することをM回繰り返す だけでよいのですが、通常のリストを使うと、最大値の選定に $O(N)$ 、それをM回のため計算量が $O(NM)$ で間に合いません。
そこでheapqの出番です。heapqでは最大値の選定に $O(\log{N})$ 、それをM回のため計算量が $O(M \log {N})$ となり間に合います。以下が実装です。
import heapq
n, m = map(int, input().split())
a = list(map(lambda x: int(x)*(-1), input().split()))
heapq.heapify(a) # aを優先度付きキューに
for _ in range(m):
tmp_min = heapq.heappop(a)
heapq.heappush(a, (-1)*(-tmp_min//2)) # 負数の剰余演算を避けるため一時的に0以上の整数にしています
print(-sum(a))
おわりに
読んでいただきありがとうございました。指摘等ありましたら是非コメントお願いします。