競プロ強くなりたいので、このサイト(AtCoder 版!蟻本 (初級編))を参考に、AtCoderの問題を解いてアルゴリズムの基礎固めしています。今回解いたのは、この問題(の中のC-ダーツ)です。競プロやるうえで基本となる二分探索(計算量O(logN))ですが、どこで使うか判断するのがなかなか難しいです。。
問題概要
4本の矢を投げて、点数$P_1, P_2, ..., P_N$ の中で当たった点数の合計のうち、$M$ を超えないものの最大値を求めます。矢を「投げない」という選択もOKです。
$S \leq M$
$1 \leq N \leq 1000$
$1 \leq M \leq 2 \times 10^8$
解法
※ここに解説は載っていますが、自分の理解のためにまとめます。
一番最初に思いつくのは、$(N+1)^4$ 通りの全てのパターンを調べることです(「+1」は、「投げない」場合としてP=0の矢を加えている分です)。つまり全探索ですね。4重ループを書くだけで実装できます。しかしこれだと、計算量が$O(N^4)$ となってしまい、$N = 100~200$ 程度でないとACできません。
そこで一工夫して、二分探索を使用すると、$O(N^2logN)$ 時間で解くことができます。解説に載っている、解法3のやり方です。
矢を「4回投げる」のではなく、「2回投げるのを2回行う」と考えます。矢を2回投げたとき、得られる点数は$(N+1)^2$ 通り以下です。この結果を、$Q_1, Q_2, ...., Q_r$ として、初めにソートしておきましょう。計算量は$O(N^2logN)$ です。
→「つまり、点数がr通りあった時、矢を2回投げた時の合計がM以下」と言い換えることができます。
最初の2本の矢の合計を$Q_i$ であったとすると、残りの2本の矢の合計$(Q_j)$ の最適解は、
$Q_i + Q_j \leq M$
を満たす$j$ のうち、$Q_j$ が最も大きくなるものを求めればよいです。このような$j$ は二分探索により、$O(N^2logN)$ 時間で求めることができます。
下準備の時間も合わせて、全体で$O(N^2logN)$ 時間で求めることができます。
pythonによる解答例が以下です。ライブラリbisectを使うと簡単に二分探索できます。
import bisect
N, M = map(int, input().split())
p_list = [int(input()) for _ in range(N)]
p_list.append(0)
p_list.sort()
q_list = []
# Qiのリストを作成
for i in range(N+1):
for j in range(i, N+1): # 重複がないようにしている・・・★
q = p_list[i] + p_list[j]
if q <= M:
q_list.append(q)
q_list.sort()
ans = 0
for q1 in q_list:
# 二分探索(合計がMの時に正しく計算するため、rightにする必要あり)
insert_index = bisect.bisect_right(q_list, M-q1)
tmp_ans = q1 + q_list[insert_index-1]
ans = max(ans, tmp_ans)
print(ans)
ちなみに、コード中の★の部分を、
for i in range(N+1):
for j in range(N+1):
q = p_list[i] + p_list[j]
if q <= M:
q_list.append(q)
として計算して後でソートしようとしたらMLEでした。メモリって大事。