12
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

【Python】K番目に大きい組み合わせを取るための3つの方法【heapq 優先度付きキュー】

Last updated at Posted at 2019-04-12

今回は「3つのパラメータによって変化する数値dについて、大きいほうからK番目までを取り出してくる」という問題について3つの解法を考察する。

良問ぞろいのAtcoder Beginner Contest D問題から抜粋する。問題は以下である。
ABC123 D問題
###問題

AtCoder 洋菓子店は数字の形をしたキャンドルがついたケーキを販売しています。
ここには 1, 2, 3 の形をしたキャンドルがついたケーキがそれぞれ X 種類、Y 種類、Z 種類あります。
それぞれのケーキには「美味しさ」という整数の値が以下のように割り当てられています。
・1 の形のキャンドルがついたケーキの美味しさはそれぞれA1, A2, ... , AX
・2 の形のキャンドルがついたケーキの美味しさはそれぞれB1, B2, ... , BY
・3 の形のキャンドルがついたケーキの美味しさはそれぞれC1, C2, ... , CZ
高橋君は ABC 123 を記念するために、1,2,3 の形のキャンドルがついたケーキを 1 つずつ買うことにしました。
そのようにケーキを買う方法は X × Y × Z 通りあります。
これらの選び方を 3 つのケーキの美味しさの合計が大きい順に並べたとき、1, 2, ... , K 番目の選び方でのケーキの美味しさの合計をそれぞれ出力してください。

###制約

・ 1 ≤ X ≤ 1000
・ 1 ≤ Y ≤ 1000
・ 1 ≤ Z ≤ 1000
・ 1 ≤ K ≤ min(3000, X x Y x Z)
・ 1 ≤ Ai ≤ 10 000 000 000
・ 1 ≤ Bi ≤ 10 000 000 000
・ 1 ≤ Ci ≤ 10 000 000 000
・入力中の値はすべて整数である。

いま、Aのリストをa_list、Bのリストをb_list、Cのリストをc_listとする。
例えば、それぞれ[6,4]、[5,1]、[8,3]があたえられたとする。このとき、大きい順に8番目までを出力することを考える。このとき、出力例は19,17,15,14,13,12,10,8である。

さて、まずは一番安直な方法を考える。
A,B,Cの考えるうる値すべてについてA+B+Cを計算し、それらを収めたリストをソートする。このリストのはじめからK番目までを出力すればいい。
全探索とソートにより、計算量は O(XYZ + XYZlogXYZ)である。
X,Y,Zが小さい値ならば計算時間を考慮しなくてもよいかもしれないが、最大値を考えるとX * Y * Z = 10 ** 9 であり、実行時間制限に間に合わない。

##解法1:大きいほうからK番目までのA+Bのみ考える
大きいほうからK番目を出力する問題であり、K+1番目以降は無視してもいいということになる。いま、選択した値にA,B,Cついて、A+Bが全(A+B)の中で大きい順にK+1番以降であったとする。このとき、どんなCを持ってきてもA+B+CがK番目以内に入ることはない。
解法1はこの性質を利用する。
まずA+Bを全探索し、ソートする。上位K個のA+Bをリストにおさめ、K個の(A+B)とZ個のCすべての組み合わせを計算。最後にK*Z個のA+B+Cをソートし、K番目までを出力する。
このとき計算量はO(XY+XYlogXY+KZ+KZlogKZ)である。
最大値を考えると10**7程度である。Pythonだと限界に近い実行時間である。

解法1をPythonで記述すると以下となる。

test1.py
# input
X,Y,Z,K = map(int, input().split())
a_list = [int(i) for i in input().split()]
b_list = [int(i) for i in input().split()]
c_list = [int(i) for i in input().split()]

# list top K of a+b
a_plus_b_list = []
for a in a_list:
    for b in b_list:
        a_plus_b_list.append(a+b)
a_plus_b_list.sort(reverse = True)

# list top K of a+b+c
d_list = []
for a_b in a_plus_b_list:
    for c in c_list:
        d_list.append(a_b + c)
d_list.sort(reverse = True)
for i in range(K):
    print(d_list[i])

##解法2:p x q x rがKより大きいものを無視
いま、a_list,b_list,c_listを降順にソートします。上からそれぞれp,q,r番目を選択したとする。このとき、p x q x rがKより大きい組み合わせ(A,B,C)について、A+B+Cが上位K番に入ることはない。
仮にp x q x rがKより大きいときに、A+B+Cが上位K番以内に入っていたとする。(p,q,r)の代わりにp-1,q-1,r-1番目のいずれかを使ってもK番以内に入ることができるが、これを満たす組み合わせがK個を超えてしまうことから証明できます。

まず、降順にソートされたa_list,b_list,c_listをの上から参照していき、p(番目) x q(番目) x r(番目)がK以内であればd_listにa+b+cを追加、Kより大きければbreakして次のループに移る。
最後にd_listをソートし、上位K番目までを出力する。
最大値 X,Y,Z=1000, K=3000では、(p x q x r)<= 3000 を満たす組み合わせは106,307通りある。このとき、計算量はN = 106,307としてO(N + NlogN)、10**6程度である。

解法2をPythonで記述すると以下となる。

test2.py
# input
X,Y,Z,K = map(int, input().split())
a_list = [int(i) for i in input().split()]
b_list = [int(i) for i in input().split()]
c_list = [int(i) for i in input().split()]

# sort list
a_list.sort(reverse = True)
b_list.sort(reverse = True)
c_list.sort(reverse = True)

# solve
d_list = []
for p in range(1,X+1):
    for q in range(1,Y+1):
        for r in range(1,Z+1):
            if p * q * r <= K:
                d_list.append(a_list[p-1] + b_list[q-1] + c_list[r-1])
            else:
                break
d_list.sort(reverse = True)
for i in range(K):
    print(d_list[i])

##解法3:優先度付きキュー、heapqを用いた解法
解法2のp,q,rをここでも用いる。
いま、(p,q,r)番目の(A,B,C)が上位K番以内に入っていることが保証されているとする。このとき、次にどの組み合わせを参照すればいいだろうか。a,b,cそれぞれのリストは降順でソートされているため、(p+q+r)の次に(A+B+C)が大きい組み合わせは(p+1,q,r)、(p,q+1,r)、(p,q,r+1)のいずれかである。

この前提において、優先度付きキューを用いる方法を考える。

優先度付きキューは、作成したキュー hq (type = list) 内において、一番大きな値をもつ要素を常に保持してくれる。

python では優先度付きキューとしてheapqモジュールが用意されている。

###python heapqモジュールの使い方

####heapq からheappush, heappopのインポート

from heapq import heappush, heappop

今回はheappushとheappopのみ使用する。

####キュー(空リスト)の作成

hq = []

このhqに値が追加されていく。変数名は自由だが "hq" や "Q" が一般的。

####heappushでキューに値を追加

heappush(hq, (val, 0, 0, 0))

heappushの第一引数には、追加先のキュー名、第二引数には追加したい値のタプルを入れる。第二引数のタプル内では、1番目に比較に用いたい値(val)
、2番目以降にパラメータの値を入れる。

####heappopで一番小さい値を取り出す

>>>heappop(hq)
(val,0,0,0)
>>>hq
[]

heappop(キュー名)で、キュー内の比較値が最も小さいものが取り出さされる。ここで注意すべきことは、heappopは最も小さい値をとりだすため、今回の問題のような場合には比較する(A+B+C)にマイナスをかけた値を用いる必要がある。

heappopをした後はpopされた要素は削除される。

さて、このheapqを用いて問題を解く。
はじめにhqに、(p,q,r) = (1,1,1)、つまりA+B+Cが最も大きい組み合わせを追加する。その後、以下の操作をK回繰り返す。
・heappopで一番小さい値を取り出す(-(A+B+C)を比較値に用いるため、A+B+Cが大きい順に取り出される)
・(p+1,q,r)がまだ追加されていなければ、hqに追加する
・(p,q+1,r)がまだ追加されていなければ、hqに追加する
・(p,q,r+1)がまだ追加されていなければ、hqに追加する

解法3の計算量は、a,b,cリストのソート、K回のheapqの操作でO(XlogX + YlogY + ZlogZ + KlogK)、最大値でも10**4程度とほかの解法に比べ高速である。

解法3をPythonで記述すると以下となる。

test3.py
from heapq import heappush, heappop

# input
X,Y,Z,K = map(int, input().split())
a_list = [int(i) for i in input().split()]
b_list = [int(i) for i in input().split()]
c_list = [int(i) for i in input().split()]
a_list.sort(reverse = True)
b_list.sort(reverse = True)
c_list.sort(reverse = True)

#solve
hq = []
arg_hash = {}
val = a_list[0] + b_list[0] + c_list[0]
heappush(hq, (-val,0,0,0))

for i in range(K):
    ans, a, b, c = heappop(hq)
    print(-ans)

    arg_a = (a+1, b, c)
    arg_b = (a, b+1, c)
    arg_c = (a, b, c+1)

    if a < X-1 and arg_a not in arg_hash:
        arg_hash[arg_a] = 1
        heappush(hq, (-(a_list[a+1]+b_list[b]+c_list[c]),)+arg_a)
    if b < Y-1 and arg_b not in arg_hash:
        arg_hash[arg_b] = 1
        heappush(hq, (-(a_list[a]+b_list[b+1]+c_list[c]),)+arg_b)
    if c < Z-1 and arg_c not in arg_hash:
        arg_hash[arg_c] = 1
        heappush(hq, (-(a_list[a]+b_list[b]+c_list[c+1]),)+arg_c)

arg_hashにキューに追加された組み合わせを記憶させている。

12
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?