今回は「3つのパラメータによって変化する数値dについて、大きいほうからK番目までを取り出してくる」という問題について3つの解法を考察する。
良問ぞろいのAtcoder Beginner Contest D問題から抜粋する。問題は以下である。
ABC123 D問題
###問題
AtCoder 洋菓子店は数字の形をしたキャンドルがついたケーキを販売しています。
ここには 1, 2, 3 の形をしたキャンドルがついたケーキがそれぞれ X 種類、Y 種類、Z 種類あります。
それぞれのケーキには「美味しさ」という整数の値が以下のように割り当てられています。
・1 の形のキャンドルがついたケーキの美味しさはそれぞれA1, A2, ... , AX
・2 の形のキャンドルがついたケーキの美味しさはそれぞれB1, B2, ... , BY
・3 の形のキャンドルがついたケーキの美味しさはそれぞれC1, C2, ... , CZ
高橋君は ABC 123 を記念するために、1,2,3 の形のキャンドルがついたケーキを 1 つずつ買うことにしました。
そのようにケーキを買う方法は X × Y × Z 通りあります。
これらの選び方を 3 つのケーキの美味しさの合計が大きい順に並べたとき、1, 2, ... , K 番目の選び方でのケーキの美味しさの合計をそれぞれ出力してください。
###制約
・ 1 ≤ X ≤ 1000
・ 1 ≤ Y ≤ 1000
・ 1 ≤ Z ≤ 1000
・ 1 ≤ K ≤ min(3000, X x Y x Z)
・ 1 ≤ Ai ≤ 10 000 000 000
・ 1 ≤ Bi ≤ 10 000 000 000
・ 1 ≤ Ci ≤ 10 000 000 000
・入力中の値はすべて整数である。
いま、Aのリストをa_list、Bのリストをb_list、Cのリストをc_listとする。
例えば、それぞれ[6,4]、[5,1]、[8,3]があたえられたとする。このとき、大きい順に8番目までを出力することを考える。このとき、出力例は19,17,15,14,13,12,10,8である。
さて、まずは一番安直な方法を考える。
A,B,Cの考えるうる値すべてについてA+B+Cを計算し、それらを収めたリストをソートする。このリストのはじめからK番目までを出力すればいい。
全探索とソートにより、計算量は O(XYZ + XYZlogXYZ)である。
X,Y,Zが小さい値ならば計算時間を考慮しなくてもよいかもしれないが、最大値を考えるとX * Y * Z = 10 ** 9 であり、実行時間制限に間に合わない。
##解法1:大きいほうからK番目までのA+Bのみ考える
大きいほうからK番目を出力する問題であり、K+1番目以降は無視してもいいということになる。いま、選択した値にA,B,Cついて、A+Bが全(A+B)の中で大きい順にK+1番以降であったとする。このとき、どんなCを持ってきてもA+B+CがK番目以内に入ることはない。
解法1はこの性質を利用する。
まずA+Bを全探索し、ソートする。上位K個のA+Bをリストにおさめ、K個の(A+B)とZ個のCすべての組み合わせを計算。最後にK*Z個のA+B+Cをソートし、K番目までを出力する。
このとき計算量はO(XY+XYlogXY+KZ+KZlogKZ)である。
最大値を考えると10**7程度である。Pythonだと限界に近い実行時間である。
解法1をPythonで記述すると以下となる。
# input
X,Y,Z,K = map(int, input().split())
a_list = [int(i) for i in input().split()]
b_list = [int(i) for i in input().split()]
c_list = [int(i) for i in input().split()]
# list top K of a+b
a_plus_b_list = []
for a in a_list:
for b in b_list:
a_plus_b_list.append(a+b)
a_plus_b_list.sort(reverse = True)
# list top K of a+b+c
d_list = []
for a_b in a_plus_b_list:
for c in c_list:
d_list.append(a_b + c)
d_list.sort(reverse = True)
for i in range(K):
print(d_list[i])
##解法2:p x q x rがKより大きいものを無視
いま、a_list,b_list,c_listを降順にソートします。上からそれぞれp,q,r番目を選択したとする。このとき、p x q x rがKより大きい組み合わせ(A,B,C)について、A+B+Cが上位K番に入ることはない。
仮にp x q x rがKより大きいときに、A+B+Cが上位K番以内に入っていたとする。(p,q,r)の代わりにp-1,q-1,r-1番目のいずれかを使ってもK番以内に入ることができるが、これを満たす組み合わせがK個を超えてしまうことから証明できます。
まず、降順にソートされたa_list,b_list,c_listをの上から参照していき、p(番目) x q(番目) x r(番目)がK以内であればd_listにa+b+cを追加、Kより大きければbreakして次のループに移る。
最後にd_listをソートし、上位K番目までを出力する。
最大値 X,Y,Z=1000, K=3000では、(p x q x r)<= 3000 を満たす組み合わせは106,307通りある。このとき、計算量はN = 106,307としてO(N + NlogN)、10**6程度である。
解法2をPythonで記述すると以下となる。
# input
X,Y,Z,K = map(int, input().split())
a_list = [int(i) for i in input().split()]
b_list = [int(i) for i in input().split()]
c_list = [int(i) for i in input().split()]
# sort list
a_list.sort(reverse = True)
b_list.sort(reverse = True)
c_list.sort(reverse = True)
# solve
d_list = []
for p in range(1,X+1):
for q in range(1,Y+1):
for r in range(1,Z+1):
if p * q * r <= K:
d_list.append(a_list[p-1] + b_list[q-1] + c_list[r-1])
else:
break
d_list.sort(reverse = True)
for i in range(K):
print(d_list[i])
##解法3:優先度付きキュー、heapqを用いた解法
解法2のp,q,rをここでも用いる。
いま、(p,q,r)番目の(A,B,C)が上位K番以内に入っていることが保証されているとする。このとき、次にどの組み合わせを参照すればいいだろうか。a,b,cそれぞれのリストは降順でソートされているため、(p+q+r)の次に(A+B+C)が大きい組み合わせは(p+1,q,r)、(p,q+1,r)、(p,q,r+1)のいずれかである。
この前提において、優先度付きキューを用いる方法を考える。
優先度付きキューは、作成したキュー hq (type = list) 内において、一番大きな値をもつ要素を常に保持してくれる。
python では優先度付きキューとしてheapqモジュールが用意されている。
###python heapqモジュールの使い方
####heapq からheappush, heappopのインポート
from heapq import heappush, heappop
今回はheappushとheappopのみ使用する。
####キュー(空リスト)の作成
hq = []
このhqに値が追加されていく。変数名は自由だが "hq" や "Q" が一般的。
####heappushでキューに値を追加
heappush(hq, (val, 0, 0, 0))
heappushの第一引数には、追加先のキュー名、第二引数には追加したい値のタプルを入れる。第二引数のタプル内では、1番目に比較に用いたい値(val)
、2番目以降にパラメータの値を入れる。
####heappopで一番小さい値を取り出す
>>>heappop(hq)
(val,0,0,0)
>>>hq
[]
heappop(キュー名)で、キュー内の比較値が最も小さいものが取り出さされる。ここで注意すべきことは、heappopは最も小さい値をとりだすため、今回の問題のような場合には比較する(A+B+C)にマイナスをかけた値を用いる必要がある。
heappopをした後はpopされた要素は削除される。
さて、このheapqを用いて問題を解く。
はじめにhqに、(p,q,r) = (1,1,1)、つまりA+B+Cが最も大きい組み合わせを追加する。その後、以下の操作をK回繰り返す。
・heappopで一番小さい値を取り出す(-(A+B+C)を比較値に用いるため、A+B+Cが大きい順に取り出される)
・(p+1,q,r)がまだ追加されていなければ、hqに追加する
・(p,q+1,r)がまだ追加されていなければ、hqに追加する
・(p,q,r+1)がまだ追加されていなければ、hqに追加する
解法3の計算量は、a,b,cリストのソート、K回のheapqの操作でO(XlogX + YlogY + ZlogZ + KlogK)、最大値でも10**4程度とほかの解法に比べ高速である。
解法3をPythonで記述すると以下となる。
from heapq import heappush, heappop
# input
X,Y,Z,K = map(int, input().split())
a_list = [int(i) for i in input().split()]
b_list = [int(i) for i in input().split()]
c_list = [int(i) for i in input().split()]
a_list.sort(reverse = True)
b_list.sort(reverse = True)
c_list.sort(reverse = True)
#solve
hq = []
arg_hash = {}
val = a_list[0] + b_list[0] + c_list[0]
heappush(hq, (-val,0,0,0))
for i in range(K):
ans, a, b, c = heappop(hq)
print(-ans)
arg_a = (a+1, b, c)
arg_b = (a, b+1, c)
arg_c = (a, b, c+1)
if a < X-1 and arg_a not in arg_hash:
arg_hash[arg_a] = 1
heappush(hq, (-(a_list[a+1]+b_list[b]+c_list[c]),)+arg_a)
if b < Y-1 and arg_b not in arg_hash:
arg_hash[arg_b] = 1
heappush(hq, (-(a_list[a]+b_list[b+1]+c_list[c]),)+arg_b)
if c < Z-1 and arg_c not in arg_hash:
arg_hash[arg_c] = 1
heappush(hq, (-(a_list[a]+b_list[b]+c_list[c+1]),)+arg_c)
arg_hashにキューに追加された組み合わせを記憶させている。