概要
AtcoderのABC219のD問題が(おそらく)あと一歩のところで解けず、非常に悔しい思いをしたので、
忘れないよう記事にまとめました。
問題文
N 種類の弁当が、それぞれ 1 個ずつ売られています。
i=1,2,…,N について、i 種類目の弁当には Ai個のたこ焼きと Bi個のたい焼きが入っています。
高橋君は、 X 個以上のたこ焼きと Y 個以上のたい焼きを食べたいです。
高橋君がいくつかの弁当を選んで買うことで、 X 個以上のたこ焼きと Y 個以上のたい焼きを手に入れることが可能かどうか判定して下さい。また、可能な場合はそのために高橋君が購入しなければならない弁当の個数の最小値を求めて下さい。
各種類の弁当は 1 個しか売られていないため、同じ種類の弁当を 2 個以上購入することは出来ないことに注意して下さい。
制約
$1≤N≤300$
$1≤X,Y≤300$
$1≤Ai,Bi≤300$
入力はすべて整数
解き方の方針
ナップザック問題を解いたことのある人なら、同様の方法、少なくともDPで解くというのは、思い浮かべるのではないでしょうか?
自分も、DPで解くことはすぐに考えたのですが、三次元の配列でi, x, yを用いた状態の定め方がいまいちピンとぴんと来ませんでした。
そこで、たこ焼きをx個、たい焼きをy個、手に入れるときのお弁当の最小の個数で考えると、二次元の配列で収まり、自分でも実装できそうなイメージでした。
具体的な状態は以下です。
たこ焼きをx個、たい焼きをy個持っているとき、
$dp[x][y] = お弁当の個数$
i種類目のお弁当を買うとき、
$dp[x+a][y+b] = dp[x][y] + 1$
ここで、厄介になってくるのが、x+a、y+bの範囲ですが、弁当全てを購入した時のたい焼き、たこ焼きの最大の個数は、
NX, NYとなり、これに対して全ての計算を行うと、$O(N^2*XY)$で10^9のオーダーとなってしまいます。
そこで、今回はX, Yより多くなるときは、切り捨てて考えることで、
x <= X, y <= Yが担保できます。
実装は以下です。
N = int(input())
x, y = map(int, input().split())
# dpの配列を準備。x, yを超える時は、切り捨てて考えるので、配列の大きさも以下で十分。値は、最大値+1で初期化。
dp = [[N+1]*(y+1) for _ in range(x+1)]
# たこ焼き0個、たい焼き0個の時は弁当の数も0
dp[0][0] = 0
bento = []
for _ in range(N):
a, b = map(int, input().split())
bento.append([a, b])
# i+aがxより大きい時は、xに変換。j+bも同様にyに変換。
# お弁当nを買った時、全てのdpに反映させるため、x, yでそれぞれfor文を回している。
for n in range(N):
a, b = bento[n]
for i in range(x+1):
for j in range(y+1):
k = min(i+a, x)
h = min(j+b, y)
dp[k][h] = min(dp[k][h], dp[i][j]+1)
# 答えは、x個、y個(以上)の時なので、dp[x][y]を出力。ただし、Nより大きい時は、初期値のまま(全部買っても成り立たない時)なので-1に変換
ans = dp[x][y]
if ans > N:
ans = -1
print(ans)
上記で、見事AC…には、ならず、結局時間いっぱいまでACに出来ませんでした。
皆さんは、なぜ通らないかお気づきでしょうか?
正解は、
for i in range(x+1):
for j in range(y+1):
k = min(i+a, x)
h = min(j+b, y)
dp[k][h] = min(dp[k][h], dp[i][j]+1)
この部分で、iとjをそれぞれ小さい順に行なっており、この実装では、
同じ弁当を複数回買った時の状態も含まれてしまうためです。
小さい順のイメージ:
たこ焼き1個、たい焼き2個のお弁当を買った時、
dp[1][2] = 1
dp[2][4] = 2
dp[3][6] = 3 ...
これを防ぐためには、iとjを大きい順に実装すれば、重複を防ぐことが出来ます。
大きい順のイメージ:
dp[1][2] = 1, dp[2][4] = 3, dp[3][6] = 7
たこ焼き1個、たい焼き2個のお弁当を買う時、
dp[3][6] = dp[2][4] + 1 = 4
dp[2][4] = dp[1][2] + 1 = 2
というわけで、ACになる実装は、以下になります。
N = int(input())
x, y = map(int, input().split())
dp = [[N+1]*(y+1) for _ in range(x+1)]
dp[0][0] = 0
bento = []
for _ in range(N):
a, b = map(int, input().split())
bento.append([a, b])
for n in range(N):
a, b = bento[n]
for i in range(x+1)[::-1]:
for j in range(y+1)[::-1]:
k = min(i+a, x)
h = min(j+b, y)
dp[k][h] = min(dp[k][h], dp[i][j]+1)
ans = dp[x][y]
if ans > N:
ans = -1
print(ans)
最後に
forの後に[::-1]をつければ良かったんですね。。。。
皆さんもDPの問題で謎のWAに行き詰まった時は、
- 逆から加算してみる
- 制約に触れていないかを考える(今回でいえば、同じものを含まないなど)
などをぜひ考慮してみてください。
ここまでお読みいただき、ありがとうございました。