LoginSignup
1
0

More than 1 year has passed since last update.

全方位木dpをpythonで

Last updated at Posted at 2021-07-29

全方位木dpとは

ある頂点についての何かしらの通り数はDFSによる木dpで$O(N)$で計算できるとき、その通り数を全頂点について求めるdpです。
愚直にDFSすると当然$O(N^2)$となりますが、1つの頂点に対するDFSの結果を再利用することにより、$O(N)$の計算量を達成します。

全方位木dpのアルゴリズム自体の解説はわかりやすいものが大量にネットに転がっているのでここでは割愛させていただきます。(説明能力の欠如)

概念は理解できたものの自力で実装ができず、他の方々のコードを読んでも全く理解できずにドはまりしたので、コメント多めのものを以下に記載します。

解いた問題は↓です。
Educational DP Contest / DP まとめコンテスト V - Subtree

用いている関数とその用途は以下の通りです。
・dfs1 : 頂点0を根とする木dp(各辺について1度ずつdp計算。ここで調べた方向を順方向とします)
・dfs2 : 順方向と逆方向の木dp

dfs1ではdp1を、dfs2ではdp2を更新していきますが、重要なのは「dp1[cu]はcuを黒く塗る通り数を、dp2[cu]はcuの1つ前の頂点を黒く塗る通り数を表している」ということです。

これを理解するために相当の時間を溶かしました。

import sys
sys.setrecursionlimit(10**6)

N,M=map(int,input().split())
xy=[[int(i) for i in input().split()] for _ in range(N-1)]

G=[[] for _ in range(N)]
for x,y in xy:
    x-=1; y-=1
    G[x].append(y)
    G[y].append(x)

#順方向のdp
#dp1[cu] : 順方向でcuを黒く塗る通り数
dp1=[1]*N
def dfs1(cu, p=-1):
    for to in G[cu]:
        if to==p: continue
        dfs1(to,cu)
        dp1[cu]*=dp1[to]+1
        dp1[cu]%=M

ans=[-1]*N

#逆方向のdp
#dp2[cu] : 逆方向で「cuの1つ前の頂点」を黒く塗る通り数
dp2=[1]*N
def dfs2(cu, p=-1):

    #dp1の左右からの累積配列を計算
    acc_l=[1]*len(G[cu])
    acc_r=[1]*len(G[cu])
    for i in range(len(G[cu])):
        to = G[cu][i]
        if i-1>=0:
            acc_l[i]=acc_l[i-1]
        if to==p:
            continue
        acc_l[i]*=dp1[to]+1
        acc_l[i]%=M
    for i in range(len(G[cu])-1, -1, -1):
        to = G[cu][i]
        if i+1<len(G[cu]):
            acc_r[i]=acc_r[i+1]
        if to==p:
            continue
        acc_r[i]*=dp1[to]+1
        acc_r[i]%=M

    #逆方向のdp
    for i in range(len(G[cu])):
        to = G[cu][i]
        if to==p:
            continue
        tmp=1
        if i-1>=0:
            tmp*=acc_l[i-1]
            tmp%=M
        if i+1<len(G[cu]):
            tmp*=acc_r[i+1]
            tmp%=M
        dp2[to]=(dp2[cu]*tmp+1)%M
        dfs2(to,cu)

    #ans[cu]=(p以外からのdp1)*(p(=cuの1つ前の頂点)からのdp2)
    ans[cu]=1
    if len(acc_r):
        ans[cu]=acc_r[0]
    if p>-1:
        ans[cu]*=dp2[cu]
        ans[cu]%=M

dfs1(0)
dfs2(0)
for a in ans:
    print(a)
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0