LoginSignup
2
1

More than 3 years have passed since last update.

[Python] 01ナップサック問題 ABC032D

Last updated at Posted at 2021-01-29

ABC032D

01ナップサック問題

  • N個の荷物があり、$i(1\le i\le N)$ 番目の荷物には価値 $v_i$ と重さ $w_i$ が割り当てられている。
  • 許容重量Wのナップサックが1つある。
  • 重さの和がW以下となるように荷物の集合を選びナップサックに詰め込むとき、価値の和の最大値を求めよ。ただし、同じ荷物は一度しか選ぶことができない。

次の3パターンのデータセットが与えられる。

  • ① $1\le N\le 30, 1\le W \le 10^9, 1\le v_i \le 10^9, 1\le w_i \le 10^9$
  • ② $1\le N\le 200, 1\le W \le 10^9, 1\le v_i \le 10^9, 1\le w_i \le 1000$
  • ③ $1\le N\le 200, 1\le W \le 10^9, 1\le v_i \le 1000, 1\le w_i \le 10^9$

時間計算量と空間計算量の制約に対応するために、それぞれのパターンに応じた計算方法を採る必要がある。

➀ N≦30

全列挙 $O(2^N)$ ではTLEになる。
動的計画法は、空間計算量が膨大でMLEになる。
枝刈りで行けそうだが、半分全列挙 $O(2^\frac{N}{2}\log {2^\frac{N}{2}})$ で確実に行う。
高速に探索できるように、最大値を実現できる可能性のある [重み,価値] のみを列挙する。

➁ w≦1000

空間計算量に配慮し、動的計画法を次のように設計する。一般的なナップサック問題の解法となる。
時間計算量は $O(N^2 W_{max})$ で空間計算量は $O(NW_{max})$ となる。

$dp[i][j]$の定義:
$i$ 番目までの荷物を選んで、重みの和 $j$ 以下を満たす、価値の和の最大値

dp初期条件:

dp[0..N][0..W]=0

dp漸化式の定義:

dp[i][j]=max(dp[i-1][j], dp[i-1][j-w_i]+v_i) \hspace{15pt}(j-w_i\ge 0)

求める解:

dp[N][W]

➂ v≦1000

空間計算量に配慮し、動的計画法を次のように設計する。
時間計算量は $O(N^2 V_{max})$ で空間計算量は $O(NV_{max})$ となる。

$dp[i][j]$の定義:
$i$ 番目までの荷物を選んで、価値の和 $j$ を達成する、重みの和の最小値

dp初期条件:

dp[0][0]=0,  dp[0..N][0..V_{max}]=\infty

dp漸化式の定義:

dp[i][j]=min(dp[i-1][j], dp[i][j-v_i]+w_i) \hspace{15pt}(j-v_i\ge 0)

求める解:

max\{\ j\ |\ dp[N][j]\le W\ \}

PyPy3ならACするが、Python3ではTLEとなる。

サンプルコード
N,W = map(int,input().split())
vw = [list(map(int,input().split())) for _ in [0]*N]

vM,wM = 0,0
for v,w in vw:
    vM = max(vM,v)
    wM = max(wM,w)

if wM <= 1000: # ➁
    wS = sum(w for v,w in vw)
    if wS <= W:
        ans = sum(v for v,w in vw)
    else:
        dp = [[0]*(W+1) for _ in range(N+1)]
        for i,vw2 in enumerate(vw,1): # イテラブルオブジェクトとインデックスを取得
            v,w = vw2
            for j in range(W+1):
                dp[i][j] = dp[i-1][j]
                if j>=w : dp[i][j] = max(dp[i][j],dp[i-1][j-w] + v)
        ans = dp[-1][W]
    print(ans)
elif vM <= 1000: # ➂
    V = sum(v for v,w in vw)
    dp = [[W+1]*(V+1) for _ in range(N+1)]
    dp[0][0] = 0
    for i,vw2 in enumerate(vw,1):
        v,w = vw2
        for j in range(V+1):
            dp[i][j] = dp[i-1][j]
            if j>=v : dp[i][j] = min(dp[i][j],dp[i-1][j-v]+w)
    print(max(i for i,w in enumerate(dp[-1]) if w<=W))
elif N <= 30: # ➀
  w_max = W
  V,W = zip(*vw) # Weightでソートしたい
  # N <= 30
  # 半分全列挙
  left = [(0,0)] # weight, value
  right = [(0,0)]
  for i in range(N//2):
    left += [(x+W[i],y+V[i]) for x,y in left]
  for i in range(N//2,N):
    right += [(x+W[i],y+V[i]) for x,y in right]
  left.sort() # 重さ順
  right.sort()
  def remove_worthless(li):
    temp = []
    current_value = -1
    for w,v in li:
      if w > w_max:
        break
      # wでの最大値vをリストする
      if v > current_value:
        current_value = v
        temp.append((w,v))
    return temp
  left = remove_worthless(left)
  right = remove_worthless(right)
  INF = 10**18
  right.append((INF,0))
  # double pointer
  j = 0
  x = 0
  for wL,vL in left[::-1]: # leftを逆順
    wR_max = w_max-wL
    while right[j+1][0] <= wR_max:
      j += 1
    vLR = vL + right[j][1]
    if x < vLR:
      x = vLR
  print(x)

次はNumpy version

サンプルコード
N,w_max = map(int,input().split())
VW = [[int(x) for x in input().split()] for _ in range(N)]
V,W = zip(*VW)

def case_1():
  # N <= 30
  # 半分全列挙
  left = [(0,0)] # weight, value
  right = [(0,0)]
  for i in range(N//2):
    left += [(x+W[i],y+V[i]) for x,y in left]
  for i in range(N//2,N):
    right += [(x+W[i],y+V[i]) for x,y in right]
  left.sort() # 重さ順
  right.sort()
  def remove_worthless(li):
    temp = []
    current_value = -1
    for w,v in li:
      if w > w_max:
        break
      # valueを更新したものに制限
      if v > current_value:
        current_value = v
        temp.append((w,v))
    return temp
  left = remove_worthless(left)
  right = remove_worthless(right)
  INF = 10**18
  right.append((INF,0))
  # double pointer
  j = 0
  x = 0
  for wL,vL in left[::-1]:
    wR_max = w_max-wL
    while right[j+1][0] <= wR_max:
      j += 1
    vLR = vL + right[j][1]
    if x < vLR:
      x = vLR
  return x

def case_2():
  import numpy as np
  L = N*1000+1
  dp = np.zeros(L,dtype=np.int64) # 総重量、最大価値
  for v,w in VW:
    dp[w:] = np.maximum(dp[w:], dp[:-w] + v)
  return dp[:w_max+1].max()

def case_3():
  import numpy as np
  L = N*1000+1
  dp = np.zeros(L,dtype=np.int64) # 総価値、最小重量
  dp[1:] = 10**18
  for v,w in VW:
    dp[v:] = np.minimum(dp[v:], dp[:-v] + w)
  possible_value = (dp <= w_max).nonzero()[0]
  return possible_value.max()

if N <= 30:
  print(case_1())
elif max(W) <= 1000:
  print(case_2())
else:
  print(case_3())
2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1