はじめに
- 先日投稿した「グラフの重み最小経路について」において、ダイクストラ法を使ってABC362の問題dを解こうとしたら、半数ほどのケースでTLEとなってしまった。
- 元々、単純なダイクストラ法では、頂点数Vに対して、$O(V^2)$の計算量がかかるところ、問題dの頂点数の制約は「$2×10^5$以下」だったので、AtCoderの計算量の上限目安$10^8$を突き抜けてしまっていた。
- しかし、ダイクストラ法にヒープを利用した高速化手法を使うと、辺の数Eと頂点数Vを用いて、計算量は$O(E・logV)$になるらしい。問題dの辺の数の制約も、頂点数と同じく「$2×10^5$以下」なので、高速化手法導入により計算量が上限目安$10^8$を下回ることが分かる1。
- ヒープについては、別の投稿でswiftでの実装を行ったので、これを利用する。c++やpythonでは標準でデータ構造としてのヒープが用意されているが、swiftで用意されているデータ構造は配列のみ。スタックやキューすら、swiftにはない2。
ダイクストラ法にヒープを導入する。
- まずは、おさらいとして、ヒープ導入前のダイクストラ法のコードを見る。
let (N,M,s) = [readLine()!.split(separator:" ").map{Int($0)!}].map{($0[0],$0[1],$0[2]-1)}[0]
var gs = [[(to:Int,weight:UInt64)]](repeating:[],count:N) // 頂点[i]に対して、(繋がっている頂点to、重みweight)のタプルであり、前述のgsと比べて「重み」の情報が追加されている。
for _ in 0..<M {
let (a,b,w) = [readLine()!.split(separator:" ").map{UInt64($0)!}].map{(Int($0[0])-1,Int($0[1])-1,$0[2])}[0]
gs[a].append((b,w))
gs[b].append((a,w))
}
//前述のseenと同様、頂点[i]が到達済みか判定する
var seen = [Bool](repeating:false,count:N)
//前記seedだけで無く、頂点sから頂点[i]への総距離d[i]を記録する。最短距離になるまで、塗り替えが行われる。
var d = [UInt64](repeating:INF,count:N)
let INF = UInt64.max // 初期値は∞の代わりに、整数の最大値としている。
d[s] = 0 // 頂点sから頂点sへの総距離は当然0となる。
while (true) { // このループは、最大N回転することもあり、遅さの要因の一つとなっている
var dmin_next = INF // 次の頂点を決定するための最短距離の初期値
var v_next:Int? // 次の頂点を格納する変数
/////////////////////////////////////
/// ここが遅さの原因 /// 線形走査しているので遅くなっている
/// ↓↓ ↓↓ ↓↓ ↓↓ ↓↓ ↓↓ ↓↓ ↓↓ ↓↓ ↓↓ ///
for v in 0..<N { //全ての頂点を走査する
if !seen[v] && d[v]<dmin_next { //「頂点vが未到達」かつ「頂点vへの距離計測済みかつ最短」
dmin_next = d[v] // 最短距離を上書き
v_next = v // 最短となる次の頂点を格納 (for文の中で上書きされる可能性有り)
}
}
/// ↑↑ ↑↑ ↑↑ ↑↑ ↑↑ ↑↑ ↑↑ ↑↑ ↑↑ ↑↑ ///
/// ここが遅さの原因 /// ここをヒープに変えると速くなる
/////////////////////////////////////
guard let v = v_next else {break} // 次の頂点が無くなったらループを抜ける
//次の頂点vの更に先の頂点(g.to)までの距離情報を得る
for g in gs[v] {
if d[g.to] > d[v] + g.weight { // d[g.to]の初期値はINF
d[g.to] = d[v] + g.weight
}
}
seen[v] = true // 頂点vは到達済み(観測済み)のフラグを立てる
}
// 結果を表示
for v in 0..<N {
print(v+1,terminator:":") //頂点番号を[i]方式から、問題文の番号に戻す
if d[v]<INF { // 頂点vまで到達できるかどうかを距離がINFか否かで判定
print(d[v]) // 距離(重みの合計)を表示
} else {
print("INF") // 到達できないとき、INFを表示
}
}
- 上記コードに「ここが遅さの原因」と書いてある箇所が線形走査となっており、O(N)の計算量となっている。
- ここをヒープに置き換え、計算量をO(logN)に引き下げる。
- 別の投稿で行ったヒープの実装は、要素をIntとしていたが、上記で線形走査を置き換えるヒープとしては、要素を(最短距離d:UInt64,頂点番号v:Int)のタプルとし、第1要素の大小で並び替える「ヒープ」を実装することとなる。
- 具体的なヒープのコードは、次のとおり
struct HeapTpl_UInt64 : CustomStringConvertible { // print文の出力をdescriptionで制御するためのプロトコル
// 実体は配列
private var data:[(UInt64,Int)] = [] // 初期値はブランク
var description: String { return data.description } // インスタンスのprint文で出力されるもの
// 最大ヒープ(true)か最小ヒープ(false)かを選択するBool(初期値はfalse)
private var maxHeap:Bool
init(_ maxHeap :Bool = false){
self.maxHeap = maxHeap
}
///
///親インデックスや子インデックスを取得するための関数
///
private func pIndex(_ index:Int) -> Int { //親ノードが格納されているindex
return (index + 1) / 2 - 1
}
private func s0Index(_ index:Int) -> Int? { //子ノード(左)が格納されているindex
let s0 = 2 * (index + 1) - 1
if s0 < data.count {
return s0
} else {
return nil
}
}
private func s1Index(_ index:Int) -> Int? { //子ノード(右)が格納されているindex
let s1 = 2 * (index + 1)
if s1 < data.count {
return s1
} else {
return nil
}
}
///大小比較する関数
func gt(_ l:(UInt64,Int),_ r:(UInt64,Int))->Bool{ // greater thanよりgt
return (l.0 >= r.0 ? maxHeap : !maxHeap )
}
///
///メソッド
///
//要素の追加
mutating func insert(_ v:UInt64,_ t:Int) {
data.append((v,t))
var index = data.count - 1
var value = (v,t)
while index > 0 {
let pindex = pIndex(index)
let pValue = data[pindex]
if gt(pValue,value) {return} // 親の方が大きければ終了
data[index] = pValue
data[pindex] = value
index = pindex
}
}
//根の値変更
mutating func chgRoot(_ v:UInt64,_ t:Int) {
var index = 0 //根のindex
var value = (v,t)
data[index] = value // 根を塗り替え
while true {
guard let s0index = s0Index(index) else {return} //子ノードが無ければ終了
let s0Value = data[s0index]
guard let s1index = s1Index(index) else { // 子ノードがs0のみの時、s0とのみ比較する
if gt(value,s0Value) {return} // 子ノードの方が小さければ終了
data[index] = s0Value
data[s0index] = value
return // s1が無いとき、下層ないためここで終了
}
let s1Value = data[s1index]
let (sValue,sindex) = gt(s0Value,s1Value) ? (s0Value,s0index) : (s1Value,s1index)
if gt(value,sValue) {return} // 子ノードの方が小さければ終了
data[index] = sValue
data[sindex] = value
index = sindex
}
}
//根の取出し
mutating func popRoot()->(UInt64,Int)? {
if data.isEmpty {
return nil
} else {
let ans = data[0]
let last = data.removeLast()
if data.count > 0 {
chgRoot(last.0,last.1)
}
return ans
}
}
// isEmptyプロパティを追加
var isEmpty:Bool {data.isEmpty}
}
var h = HeapTpl_UInt64()
print(h.isEmpty) // true
h.insert(5,1)
print(h.isEmpty) // false
h.insert(10,2)
h.insert(100,3)
h.insert(200,4)
h.insert(50,5)
h.insert(20,0)
print(h) // [(5, 1), (10, 2), (20, 0), (200, 4), (50, 5), (100, 3)]
h.chgRoot(55,1)
print(h) // [(10, 2), (50, 5), (20, 0), (200, 4), (55, 1), (100, 3)]
h.popRoot()
print(h) // [(20, 0), (50, 5), (100, 3), (200, 4), (55, 1)]
- 元のIntを要素としたヒープから(UInt64,Int)を要素とするヒープへのコード書き換えは楽勝だった。ついでに、デフォを「最大ヒープ」から「最小ヒープ」に変更しておいたのと、isEmptyプロパティを追加した。本来的なら、HeapTpl<Value,T>みたいにした方がシャレオツだけど面倒だわ...自分の競プロ用のライブラリに加えるときにはやろうと思うけど。
- 上記ヒープ
HeapTpl_UInt64
を用いて、単純なダイクストラのコードを書き換えると、
let (N,M,s) = [readLine()!.split(separator:" ").map{Int($0)!}].map{($0[0],$0[1],$0[2]-1)}[0]
var gs = [[(to:Int,weight:UInt64)]](repeating:[],count:N) // 頂点[i]に対して、(繋がっている頂点to、重みweight)のタプルであり、前述のgsと比べて「重み」の情報が追加されている。
for _ in 0..<M {
let (a,b,w) = [readLine()!.split(separator:" ").map{UInt64($0)!}].map{(Int($0[0])-1,Int($0[1])-1,$0[2])}[0]
gs[a].append((b,w))
gs[b].append((a,w))
}
//前述のseenと同様、頂点[i]が到達済みか判定する
// var seen = [Bool](repeating:false,count:N) -- ヒープで履歴管理するので、seenは不要に
//前記seenだけで無く、頂点sから頂点[i]への総距離d[i]を記録する。最短距離になるまで、塗り替えが行われる。
var d = [UInt64](repeating:INF,count:N)
let INF = UInt64.max // 初期値は∞の代わりに、整数の最大値としている。
d[s] = 0 // 頂点sから頂点sへの総距離は当然0となる。
// while (true) { -- なんとなく使ってたwhile(true)は退場!!
// var dmin_next = INF // 次の頂点を決定するための最短距離の初期値 -- 最短距離取得はヒープの役割に
// var v_next:Int? // 次の頂点を格納する変数 -- 最短距離取得はヒープの役割に
// for v in 0..<N { //全ての頂点を走査する -- 最短距離取得はヒープの役割に
// if !seen[v] && d[v]<dmin_next { //「頂点vが未到達」かつ「頂点vへの距離計測済みかつ最短」
// dmin_next = d[v] // 最短距離を上書き
// v_next = v // 最短となる次の頂点を格納 (for文の中で上書きされる可能性有り)
// }
// }
// guard let v = v_next else {break} // 次の頂点が無くなったらループを抜ける -- 履歴管理をヒープが担うため、while(true)が退場し、こいつもお役御免
/// ヒープ登場
var dvhp = HeapTpl_UInt64() //デフォは最小ヒープ
dvhp.insert(d[s],0) // とりあえず、起点0のd[s]をヒープに登録
while(!dvhp.isEmpty){ //ヒープが空っぽになるまで実行。これが意味のあるピンポイントなwhileループ !!!
var (dmin_v,v) = dvhp.popRoot()! //popの計算量はO(logN)
if dmin_v > d[v] {continue} // この条件となった(dmin_v,v)のタプルはゴミなので廃棄されてスキップ
//次の頂点vの更に先の頂点(g.to)までの距離情報を得る
for g in gs[v] { //gsの要素数は全体で2*M個
if d[g.to] > d[v] + g.weight { // d[g.to]の初期値はINF
d[g.to] = d[v] + g.weight
dvhp.insert(d[g.to],g.to) //ヒープに次の頂点の番号とその頂点までの最短距離を追加 insertの計算量はO(logN)
}
}
// seen[v] = true // 頂点vは到達済み(観測済み)のフラグを立てる -- seenは不要になった
} // while(true)のお尻から、while(!dvhp.isEmpty)のお尻になった
// 結果を表示
for v in 0..<N {
print(v+1,terminator:":") //頂点番号を[i]方式から、問題文の番号に戻す
if d[v]<INF { // 頂点vまで到達できるかどうかを距離がINFか否かで判定
print(d[v]) // 距離(重みの合計)を表示
} else {
print("INF") // 到達できないとき、INFを表示
}
}
- 上記のコードを説明すると、まず、ヒープdvhpを生成している。これは、「到達済み頂点」および「当該頂点への暫定的最短距離」を格納している。よって、到達済み頂点のフラグ管理のみ行っていた配列seenが不要になったので、消した。また、最短距離および当該最短距離の先となる頂点番号を保持していたdmin_nextとv_nextが不要となったので消した。
- また、当然、最短距離の次頂点を決めるアルゴリズムを担うのが最小ヒープになったので、線形走査部分を消した。
- 完成後、念のため、下記入力と出力結果が変わらないことを確認した。
6 9 1 // 頂点の数6、辺の数9,スタート頂点1
1 2 7 // 辺1:頂点1と頂点2を結ぶ辺の重み7
1 3 9 // 辺2:頂点1と頂点3を結ぶ辺の重み9
1 6 14
2 3 10
2 4 15
3 4 11
6 3 2
6 5 9
4 5 6
- 出力
1:0 //頂点1から頂点1への距離は当然0
2:7 //頂点1から頂点2への最短距離は7
3:9 //頂点1から頂点3への最短距離は9
4:20
5:20
6:11
再チャレンジ
- 今度こそ、ABC362の問題dでTLE無しでAC出来るか確認する。
- コードは、問題dにあわせて、下記の通り書き換えた。
let (N,M) = [readLine()!.split(separator:" ").map{Int($0)!}].map{($0[0],$0[1])}[0]
let As = readLine()!.split(separator:" ").map{UInt64($0)!} //頂点の重みを格納
var gs = [[(to:Int,weight:UInt64)]](repeating:[],count:N) // 頂点[i]に対して、(繋がっている頂点to、重みweight)のタプルであり、前述のgsと比べて「重み」の情報が追加されている。
for _ in 0..<M {
let (a,b,w) = [readLine()!.split(separator:" ").map{UInt64($0)!}].map{(Int($0[0])-1,Int($0[1])-1,$0[2])}[0]
gs[a].append((b,w + As[b])) //辺の「先」の頂点の重みを加算
gs[b].append((a,w + As[a])) //辺の「先」の頂点の重みを加算
}
var d = [UInt64](repeating:INF,count:N)
let INF = UInt64.max // 初期値は∞の代わりに、整数の最大値としている。
d[0] = 0 // 頂点sから頂点sへの総距離は当然0となる。
var dvhp = HeapTpl_UInt64() //デフォは最小ヒープ
dvhp.insert(d[0],0) // とりあえず、起点0のd[s]をヒープに登録
while(!dvhp.isEmpty){ //ヒープが空っぽになるまで実行。これが意味のあるピンポイントなwhileループ !!!
var (dmin_v,v) = dvhp.popRoot()!
if dmin_v > d[v] {continue} // この条件となった(dmin_v,v)のタプルはゴミなので廃棄されてスキップ
//次の頂点vの更に先の頂点(g.to)までの距離情報を得る
for g in gs[v] {
if d[g.to] > d[v] + g.weight { // d[g.to]の初期値はINF
d[g.to] = d[v] + g.weight
dvhp.insert(d[g.to],g.to)
}
}
}
// 結果を表示
for v in 1..<N { //スタートとなる頂点は[0]
print(d[v] + As[0],terminator:" ") // 最後に頂点[0]の重みを加算
}
- 無事に500ms程度でAC出来た!ちなみに、c++の最短は50ms程度、pythonで200ms程度...
- 速度がpythonに負けとる...これはショックでかいね。pythonみたいに型の辻褄合わせを全部コンパイラにお任せしちゃうだらしねぇ妖精言語に負けるのか...でも、仕方ないね!こっちは、素人が配列を背後で動かす形で自作したのに、pythonのheapqはプロが作った歪みねぇライブラリだしね!
最後に
- 今週末は、ABC362の問題dのお陰で、「ダイクストラ法」と「ヒープ」を学んだけど、まあ、なんとかモノに出来て良かったよ。
- 早く、4完したいなぁ。
-
疎グラフの時はヒープを使った方が高速になるが、密グラフの時は場合によってはヒープを使った方が遅くなる。疎グラフ、密グラフとは、頂点の数Vと比べて、辺の数Eが若干多い程度(疎グラフ)か、桁違いに多い(密グラフ)かの違いによる。例えば頂点同士を全て結んだ場合、辺の数はV(V-1)/2となり密グラフとなる。この場合、ヒープを使った計算量$O(E・logV)$は、$O(V^2・logV)$となってしまい、単純なダイクストラ法の計算量$O(V^2)$を上回る。 ↩
-
キューやヒープが含まれるapple謹製のライブラリ「swift-collections」がgithubで公開されているので、そのうち、正式にswiftの標準ライブラリになって、paiza.ioやAtCoderでも使えるようになるかも。swiftの直近バージョンは5.9なので、バージョン6.0になるときに標準ライブラリになったら良いなぁ。swift-collectionsのバージョンも1.1.2で、1.0を超えてるんだから、正式に利用できても良さそうなのにね。 ↩