ある条件を満たす順列の数え上げに関して、n!では間に合わないけど2**nなら間に合いそうという時に使えるテクニックです。
全要素の集合をU、その部分集合をSとしてSに関して条件を満たす並び方の総数をf(S)とします。右(左)端に何を置くかで場合分けを行うと、要素数のひとつ少ないSの部分集合S'の話に帰着するので、帰納的に解くことができます。さらに、dpを効率よく行う方法として、集合をbitで管理するという発想があります。つまり、0110={1, 2}のように対応づけをすると、集合に整数でラベルを貼ることができます。このラベルをn(S)で表すことにすると、S⊇S'ならばn(S)≧n(S')が成立するので、nの小さい方から順に求めていくことで簡単に実装できます。
例題:ABC041D
def main():
N, M = map(int, input().split())
edge = [0]*N
for i in range(M):
x, y = map(int, input().split())
edge[x-1]|=1<<(y-1)
dp = [0]*(1<<N)
dp[0] = 1
for s in range(1, 1<<N):#集合を添字の小さい順に試す
for i in range(N):#全ての要素を考える
if ((s>>i)&1) and (not(edge[i]&s)):#i in sかつedge[i]とsが共通部分を持たない
dp[s]+=dp[s^(1<<i)]
return dp[-1]
print(main())
(関係ないですがmain()に埋め込むと100msくらい速くなりました。すごい。)
もう一つの例題です。
N = int(input())
x = list(map(int, input().split()))
a = list(map(int, input().split()))
edges = [0]*N
cost = [0]*N
for i, ai in enumerate(a):
edges[ai-1]+=1<<(i+1)
cost[ai-1]+=x[i+1]
disk = [0]*(1<<N)
dp = [float('inf')]*(1<<N)
dp[0] = 0
for s in range(1<<N):
for i in range(N):
if not s & (1 << i) and s & edges[i] == edges[i]:
nextS = s| (1 << i)
disk[nextS] = disk[s] + x[i]
dp[nextS] = min(dp[nextS], max(dp[s], disk[nextS]))
disk[nextS] -= cost[i]
print(dp[-1])
実はこれ、はじめは貰うDPで書いていたのですが、謎のWAが取れずにかなり時間を溶かしました。結論から言うと、貰うDPで書くと起こりえない場合に関しても考えてしまうため、diskの値に不都合が起こることがわかりました。配るDPの形で書くと、すでに起こりうるとわかっている状態からの遷移のみを考えられるので、そのような不都合がなくなるわけです。
上では色々なbit演算を用いています。いい機会なので、bit演算についてまとめておきます。
|: orのこと。
&: andのこと。
^: xorのこと。実質、繰り上がりのない引き算、足し算
<<: 左シフト。右に0をn個付け加える
>>: 右シフト。右からn桁を削除する
よく使うのは以下のような奴らです。
1<<N: 2**Nのこと
n>>k&1: nのk+1桁目が1かどうか(0-indexedなら右からk番目)
S&T: 集合の積