はじめに
問題はこちら
初心者(灰色〜茶色)向けです。
伝えたいこと
Pythonで再帰を使うときは、
- PyPyを使わない!
- pypyjitライブラリが必要!
import pypyjit pypyjit.set_param('max_unroll_recursion=-1')
D - Minimum Steiner Tree
考え方
大枠の方針
指定された頂点たち($V_1,...,V_k$)を一つ固定します。(ここでは$V_1$とします。)
不要な頂点を削除していき、求めたい状況の木になるのは、$V_1$を根とした時、葉が全て$V_1,...,V_k$たちのみで構成されているときです。1
これは、各ノードについて、子孫ノードに$V_1,...,V_k$が存在しなければそのノードを取り除いていき、残った頂点数を数えあげれば良いです。
具体的には、$V_1$を起点とした再帰関数を用いたDFSを行い、削除対象のノードを記録して数え上げることで計算可能です。
詳細
次のような再帰関数dfsを作成します。
- 自分が($V_1,...,V_k$)である場合はTrueを返す
- 1.がFalseの場合は自分の頂点が($V_1,...,V_k$)を子孫ノードに持つかどうかを返す
- 自分と隣接する未探索の頂点全てで関数dfsを実行し、自分の子孫ノード全てが($V_1,...,V_k$)を子孫ノードに持たない場合は削除対象のSetに保存する
$V_1$についてdfsを実行し、削除対象のSetの頂点数をnからひく。
注意
PyPyの場合、DFSは遅くなるらしいのです。。。
PyPyを使わずに提出するか、参考文献のように、pypyjitライブラリを用いる必要があります。
3種類提出し、実行時間を比較してみました。
- 一番上 PyPy かつ pypyjit利用あり AC
- 真ん中 CPython AC
- 一番下 PyPy かつ pypyjit利用なし TLE
解答例
#!/usr/bin/env python3
import sys
import pypyjit
from collections import defaultdict
sys.setrecursionlimit(10**6)
pypyjit.set_param('max_unroll_recursion=-1')
def dfs(v):
visited[v-1] = True
hasV_kChildNode = v in v_k_set
for neighbor in tree[v]:
if not visited[neighbor-1]:
if dfs(neighbor):
hasV_kChildNode = True
if not hasV_kChildNode:
removed.add(v)
return hasV_kChildNode
# 入力の読み込み
n, k = map(int, input().split())
tree = defaultdict(set)
for _ in range(n - 1):
a, b = map(int, input().split())
tree[a].add(b)
tree[b].add(a)
v_k = list(map(int, input().split()))
# 初期化
visited = [False] * n
removed = set()
v_k_set = set(v_k)
# DFS開始
dfs(v_k[0])
# 残ったノード数の出力
print(n - len(removed))
感想
アルゴリズムそんなに違和感がないのになぜ間に合わないのか、、、と数時間頭を悩ませたのですが、pypyjitライブラリの利用でACになりました。そもそもギリギリ間に合ってなかったのでそれほど悔しくなかったのですが、大きな学びでした。
先輩Yさんに教えていただきました。ありがとうございました。
参考
-
もし最小にならないと仮定すると、最小の木について、$V_1,...,V_k$のいずれでもない葉$L_0$が存在することになりますが、$L_0$を取り除いた木は、さらに小さい頂点数で$V_1,...,V_k$を全て含むため、取り除く前の木が最小であることに反します。 ↩