初めに
初めまして、塚っちゃん(Hiro_Gt)と申します。
この記事では、Pythonで実装したHeavy-Light分解(HL分解、Heavy-Light Decomposition)のライブラリを紹介したいと思います。
特にHL分解の仕組みよりも、自分が作成したライブラリの使用方法について解説したいと思います。
想定読者ですが、AtCoderで水色は達成していてセグメント木を理解している人だとして話を進めています。
参考
実装方針の大元は上記の記事を参考にしています。
HL分解で出来ることについては以下の記事も参照しました。
HL分解で出来ること
下記のプログラムは、重み付き木構造において以下の2つの処理を高速に行うことが出来ます。
- set:木構造上の隣接した2頂点間の辺の重みを変更する
- prod:木構造上の任意の2頂点間の最短距離における、辺の重みについての何らかの演算を行う
一般的な言葉の表現をしましたが、例えば以下のような問題を高速に解くことが出来ます。
$N$ 頂点で辺の数が$N-1$ 個の重み付き木が与えられます。以下の2種類のクエリを$Q$ 個処理してください。
- 辺 $i$ の重みを $w$ に変更する
- 頂点 $u$ と頂点 $v$ の間の距離(最短距離のパスに含まれる辺の重みの合計)を出力する
制約:$1 \leq N \leq 10^5 , 1 \leq Q \leq 10^5$
例えば上記の問題の場合、処理prodにおける演算は「辺の重みの和」になります。
なお、私が作成したHL分解のライブラリにおいては、処理prodの演算は以下の条件を満たす必要があります。
- 演算はモノイドである。つまり、以下の3条件を満たす(=セグメント木において成立する演算)
- 演算がある集合 $S$ において閉じている。つまり、$x \in S, y \in S$ において、$(x * y) \in S$ が成立する
- 演算に単位元が存在する。つまり、$x * e = e * x = x$ となる$e$ が存在する
- 演算は結合法則を満たす。つまり、$(x * y) * z = x * (y * z)$ が成立する
- 演算は可換である。つまり、$x * y = y * x$ が成立する
前者の条件について理解できない方はセグメント木を勉強すれば良いと思います。
後者の条件については、ライブラリの改良を行えば非可換でも成立させられるのではないか、と考えています。が、大抵の木構造上での演算は可換なことが多いと思うためそれほど支障がないように思われます。恐らく…。
例えば上記の問題での演算は「辺の重みについての和」になりますが、$1+3 = 3+1 = 4$ が成立するように要素を入れ替えても演算が成立します。
コードの使い方
初期化
HL分解のclassのインスタンス時において必要なのは以下の4つです。
- N:木構造の頂点数
- G:木構造の配列(0-indexed)
- ここでは、
G[i] = [(頂点iと隣接している頂点, 頂点i ~ G[i][0]間の重み)]
となるような配列を用意してください。
- ここでは、
- op:木構造上で求めたい演算の関数オブジェクト
- イメージとしては、頂点1〜2間の重みと頂点2〜3間の重みについてどんな演算を行いたいか、を指定してください。
- e:演算opの単位元
例えば下図のような木構造で、任意の2頂点間における辺の重みの合計値を求めたい場合、次のようなコードを書くことになります。(図では1-indexed, 配列では0-indexedで表現されていることに注意してください)
N = 6
G = [ [ (1, 3) ],
[ (0, 3), (2, 4) ],
[ (1, 4), (3, 5), (5, 7) ],
[ (2, 5), (4, 6) ],
[ (3, 6) ],
[ (2, 7) ] ]
def op(left, right):
return left + right
e = 0
HLD = HLDecomposition(N, G, op, e)
処理set, prodの使い方
-
set(v, w, weight)
- 頂点 $v$ と頂点 $w$ 間の辺の重みを $weight$ に変更します。(0-indexed)
-
prod(l, r)
- 頂点 $l$ と頂点 $r$ 間における辺の重みに関する演算結果を出力します。(0-indexed)
特に具体的な使用例は、以下の問題AC提出をご確認ください。
注意点
自分の実装方法では再帰関数を使用しているので、PyPyよりもCPythonで提出する方が良いと思われます。
また(恐らく自分の実装方法が冗長なせいで)実行時間にギリギリ間に合うか…といった速度になっています。そのため下記ライブラリではinput = sys.stdin.readline
を用いて標準入力受け取りの高速化コードを最初から記入しています。
もしもっと短縮できる実装方法があれば教えて頂きたい…。
HL分解 python実装コード
# 再帰関数の上限値解放
# 大抵の場合CPythonで提出する方が早いように思われます
import sys
sys.setrecursionlimit(10 ** 7)
# 標準入力受け取りの高速化
input = sys.stdin.readline
class HLDecomposition:
# N...木構造の頂点数
# G[i] = [(頂点iと隣接している頂点, i ~ G[i][0]間の重み), ...], 0-indexed
# op(left, right)...セグメント木上の演算、木上のパスの何を求めたいか
# opは可換であることを想定
# e...opの単位元
def __init__(self, N, G, op, e):
self.N = N
self.G = G
self.op = op
self.e = e
# self.parent[i] = 頂点iの親ノード
# self.depth[i] = 頂点iの木上の深さ、つまり頂点0までの距離
# self.size[i] = 頂点iより子ノードの数(i自身も含める)
# self.heavy[i] = 頂点iの子ノードのうちheavy-edgeで隣接している頂点
# -1の場合は子ノードにheavy-edgeを持たない
self.parent = [-1] * N
self.depth = [0] * N
self.size = [0] * N
self.heavy = [-1] * N
self.__build_1(0)
# self.top[i] = 頂点iが属するHeavy-edge上での根ノード
# self.HLD_G = 全てのHeavy-Pathが連続する頂点indexの配列
self.top = [0] * N
self.HLD_G = [0]
self.__build_2(0)
# 計算の都合上、逆順にしている
self.HLD_G = self.HLD_G[::-1]
# self.node_w[HLD_G[i]] = 頂点HLD_G[i]とその親ノードのパスの重み
# ただし頂点0のパスの重みは0とする
# HLD_Gに保持されている頂点の順にself.nodeが構成される
# self.node[i] = 頂点iに対応するTreeのindex
self.node_w = [0] * N
self.node = [N-1] * N
self.__make_tree()
# 頂点v ~ wを結ぶ辺の重みをweightに変更する
# 0-indexed
def set(self, v, w, weight):
# vが子ノード、wが親ノード想定
# 一応v ~ w間の辺が存在するかチェックしている
if self.parent[w] == v:
v, w = w, v
elif self.parent[v] != w:
return False
vi, wi = self.node[v], self.node[w]
self.Tree.set(vi, weight)
# 頂点l ~ rを結ぶパスのopの結果を出力する
# 0-indexed
def prod(self, l, r):
sml, smr = self.e, self.e
while True:
li, ri = self.node[l], self.node[r]
if self.top[l] == self.top[r]:
if self.depth[l] > self.depth[r]:
sml = self.op(sml, self.Tree.prod(li, ri))
else:
smr = self.op(self.Tree.prod(ri, li), smr)
return self.op(sml, smr)
else:
l_top, r_top = self.top[l], self.top[r]
if self.depth[l_top] > self.depth[r_top]:
sml = self.op(sml, self.Tree.prod(li, self.node[l_top]+1))
l = self.parent[l_top]
else:
smr = self.op(self.Tree.prod(ri, self.node[r_top]+1), smr)
r = self.parent[r_top]
# 各頂点の深さ、親ノード、自身を根とした時の子ノードの数、heavy辺を求める
# ただし木構造全体の根は頂点0を想定(0-indexed)
def __build_1(self, v0):
for v1, _ in self.G[v0]:
if v1 == self.parent[v0]:
continue
else:
self.parent[v1] = v0
self.depth[v1] = self.depth[v0] + 1
self.size[v0] += self.__build_1(v1)
temp = self.heavy[v0]
if temp == -1 or self.size[temp] < self.size[v1]:
self.heavy[v0] = v1
self.size[v0] += 1
return self.size[v0]
# topノードとHLDした結果を求める
def __build_2(self, v0):
if self.heavy[v0] != -1:
v1 = self.heavy[v0]
self.top[v1] = self.top[v0]
self.HLD_G.append(v1)
self.__build_2(v1)
for v1, _ in self.G[v0]:
if v1 == self.heavy[v0] or v1 == self.parent[v0]:
continue
else:
self.top[v1] = v1
self.HLD_G.append(v1)
self.__build_2(v1)
# HLDした結果でSegmentTreeを作成
def __make_tree(self):
for i, v0 in enumerate(self.HLD_G[:-1]):
for v1, w in self.G[v0]:
if v1 == self.parent[v0]:
self.node_w[i] = w
break
self.node[v0] = i
self.Tree = SegmentTree(self.op, self.e, self.N, self.node_w)
class SegmentTree:
# op(x, y)...演算の関数オブジェクト
# e...単位元
# n...使用する配列の長さ
# List...初期値の配列、指定しない場合は省略可能
def __init__(self, op, e, n, List=None):
self.op = op
self.e = e
self.n = n
self.log_2 = (self.n - 1).bit_length()
self.size = 1 << self.log_2
self.data = [e for _ in range(2 * self.size)]
if not List is None:
for i in range(self.n):
self.data[i + self.size] = List[i]
for i in range(self.size - 1, 0, -1):
self.data[i] = self.op(self.data[i << 1], self.data[(i << 1) | 1])
# 配列iをxに変更(iは0-indexed)
def set(self, i, x):
i += self.size
self.data[i] = x
while i > 1:
i >>= 1
self.data[i] = self.op(self.data[i << 1], self.data[(i << 1) | 1])
# l <= i < rの配列の要素の総積(0-indexed)
def prod(self, l, r):
sml, smr = self.e, self.e
l += self.size
r += self.size
while l < r:
if l & 1 == 1:
sml = self.op(sml, self.data[l])
l += 1
if r & 1 == 1:
r -= 1
smr = self.op(self.data[r], smr)
l >>= 1
r >>= 1
return self.op(sml, smr)
# 配列i(0-indexed)の要素を出力
def get(self, i):
return self.data[i + self.size]
# 配列全ての要素の総積を出力
def all_prod(self):
return self.data[1]
# f(prod(l, x))...xについて単調減少する条件式
# lが与えられた時、f(prod(l, x))=Trueとなる最大のxを出力
# ただしf(e)=True
def max_right(self, l, f):
if l == self.n:
return self.n
l += self.size
sm = self.e
while True:
while l & 1 == 0:
l >>= 1
if not f(self.op(sm, self.data[l])):
while l < self.size:
l <<= 1
if f(self.op(sm, self.data[l])):
sm = self.op(sm, self.data[l])
l += 1
return l - self.size
else:
sm = self.op(sm, self.data[l])
l += 1
if l & (-l) == l:
return self.n
# f(prod(x, r))...xについて単調増加する条件式
# rが与えられた時、f(prod(x, r))=Trueとなる最小のxを出力
# ただしf(e)=True
def min_left(self, r, f):
if r == 0:
return 0
r += self.size
sm = self.e
while True:
r -= 1
while r & 1 == 1 and r > 1:
r >>= 1
if not f(self.op(self.data[r], sm)):
while r < self.size:
r <<= 1
r += 1
if f(self.op(self.data[r], sm)):
sm = self.op(self.data[r], sm)
r -= 1
return r + 1 - self.size
else:
sm = self.op(self.data[r], sm)
if r & (-r) == r:
return 0