Computational graph (計算グラフ)とは
数値計算では、互いに依存関係がある複数のデータを逐次的、もしくは、並列的に計算することがよくあります。そのような場合に、計算グラフの概念を使うと、それらの間の依存関係を把握したり、並列して計算できる部分を抽出することが容易になります。
例えば、C = A + B
という計算を考えてみましょう。A
, B
, C
は行列でも良いですし、単なるスカラーでも良いです。
この場合、C
の評価結果は、A
, B
に依存しています。つまり、A
, B
が確定しないと、C
は評価できません。
もう少し複雑な例を考えてみましょう。
以下の計算式を考えます:
a = A + B
b = C + D
c = a + b
この場合、まずは、a
とb
を評価した後に、c
の評価が可能になります。
ただし、a
, b
の演算は原理的に並列して計算できるはずです。
このような依存関係を表すために用いられるのが 有向非巡回グラフ(Directed Acyclic Graph, DAG) です。
A B C D
│ │ │ │
▼ ▼ ▼ ▼
+---+ +---+ +---+ +---+
| A | | B | | C | | D |
+---+ +---+ +---+ +---+
│ │ │ │
└───┬───┘ └───┬───┘
▼ ▼
+---+ +---+
| a | | b |
+---+ +---+
│ │
└──────┬────────┘
▼
+---+
| c |
+---+
有向(Directed)とは、各辺には方向性があることを意味します。たとえば、A -> B は「AがBに影響を与える」または「BはAに依存する」という意味を表します。非巡回(Acyclic)とは、グラフにサイクル(閉路)が存在しないことを意味します。
上のグラフをみれば、A + B
とC + D
が並列して実行できることが一目で分かります。
このように、データ間の依存関係をDAGで表した構造を計算グラフと呼びます。一旦、計算グラフに表してしまえば、演算の並列構造を依存関係を機械的に抽出し、並列して計算できる部分を複数のプロセッサに割りあてるようなことも可能になります。
もちろん、Graphs.jl
などのような汎用的なグラフライブラリをつかって、計算グラフを実装することも出来ますが、ここでは、並列計算にも対応したDagger.jl
を使ってみましょう。
上の構造をDagger.jl
を使って実装すると、以下の様なコードになります。
using Dagger
# 入力ノードを定義
A = Dagger.@spawn () -> 1 # 定数
B = Dagger.@spawn () -> 2
C = Dagger.@spawn () -> 3
D = Dagger.@spawn () -> 3
# 中間ノードの計算
a = Dagger.@spawn A + B
b = Dagger.@spawn C + D
# 最終ノードの計算
c = Dagger.@spawn a + b
@show fetch(c)
ここで、@spawn
は、後ろに続く式を使ってDaggerのタスク(計算ノード)を定義します。@spawnによって作られる各計算タスクは依存関係を持つことができ、それらの依存に基づいて計算グラフ(DAG)が自動的に構築されます。
ここで注意すべきは、@spawn
自体は依存関係を生成するだけで、実際の計算は最後にfetch(c)
が呼ばれたときに行われます。
つまり、c
の値を評価するのに必要な全てのタスクが、依存関係に従って順次評価されます。
- 依存関係の解析
-
c
を評価するために、まずa
とb
が評価される必要があります。 -
a
を評価するためにはA
とB
、b
を評価するためにはC
とD
がそれぞれ必要です。
- 計算の実行
-
fetch(c)
が呼ばれると、c
が依存している全てのタスクが再帰的に評価されます。 - 評価順序は、Daggerが内部的に構築したDAGに基づいて決定されます。
- 並列実行の最適化 (後で説明します)
- 依存関係がないタスク(たとえば、
A
とB
、またはC
とD
)は並列に実行可能です。
上のコードを実行すると、
fetch(c) = 9
が得られるでしょう。でも、これだけだと、計算グラフを使うメリットがわかりにくいですよね。逐次的に、各ステップを即時的に計算・評価するのと結果は変わりません。
並列計算 Dagger.jl
次に、並列計算を試してみましょう。
今回試すのは、Distributed
ベースのプロセス並列です。
大きさN\times N
の2つの行列積計算をnterm
回繰り返して、すべての結果の和を計算するという簡単なセットアップです。並列効率を出すために、N=6000
に設定しています (1回の行列積計算は10秒ほど)。
色々試行錯誤した結果、以下のコードが正しく動作するようです。
using Distributed
using Dagger
using LinearAlgebra
BLAS.set_num_threads(1)
@show nworkers()
N = 6000
@everywhere function func(x, y)
flush(stdout)
t1 = time_ns()
res = sum(x * y)
t2 = time_ns()
println("computed $((t2-t1)/1e+9)")
return res
end
@everywhere func2(x...) = sum(x)
nterm = 32
matrices = []
for n in 1:nterm
m1 = Dagger.@spawn ()->rand(N, N)
m2 = Dagger.@spawn ()->rand(N, N)
push!(matrices, (m1, m2))
end
products = []
for n in 1:nterm
p = Dagger.@spawn func(matrices[n]...)
push!(products, p)
end
@show "A"
res = Dagger.@spawn func2(products...)
@show "B"
@time sum(fetch(res))
@show "C"
ポイントとしては、Dagger.@spawn
には必ず関数呼び出しの形、@spawn f(x, y)
、の形でタスクを定義することが重要です。ここで、x
, y
は固定された値でも良いですし、別の@spawn
の評価結果でも大丈夫です。一方、@spawn x + y
な形だと上手く動作しません。
私の手元のワークステーション (64コア)で実行してみたところ、
16ワーカー (julia -p 16 -t 1 test.jl
)
nworkers() = 16
"A" = "A"
"B" = "B"
From worker 16: computed 8.993550151
From worker 6: computed 9.047431069
From worker 4: computed 9.041195385
From worker 8: computed 9.039754094
From worker 2: computed 9.029207313
From worker 15: computed 9.005605376
From worker 12: computed 9.0252073
From worker 5: computed 9.028951724
From worker 10: computed 9.073104744
From worker 3: computed 9.017677308
From worker 14: computed 9.020359432
From worker 12: computed 8.880520181
From worker 3: computed 8.921722932
From worker 16: computed 8.99813879
From worker 4: computed 8.904876451
From worker 9: computed 8.900564684
From worker 6: computed 8.985334225
From worker 8: computed 8.911503034
From worker 2: computed 9.01582081
From worker 7: computed 8.928685454
From worker 13: computed 8.900212701
From worker 16: computed 8.839050387
From worker 14: computed 8.781535921
From worker 8: computed 8.821355677
From worker 15: computed 8.815256977
From worker 17: computed 8.775349723
From worker 7: computed 8.808136017
From worker 13: computed 8.792752806
From worker 6: computed 8.84895146
From worker 4: computed 8.784235002
From worker 7: computed 8.787864021
From worker 13: computed 8.747277693
43.638821 seconds (1.75 M allocations: 88.946 MiB, 0.03% gc time, 84 lock conflicts, 4.00% compilation time: 4% of which was recompilation)
"C" = "C"
32ワーカー (julia -p 32 -t 1 test.jl
)
nworkers() = 32
"A" = "A"
"B" = "B"
From worker 23: computed 10.55963046
From worker 11: computed 10.587748061
From worker 3: computed 10.603463832
From worker 20: computed 10.652398373
From worker 13: computed 10.628970288
From worker 24: computed 10.611196972
From worker 16: computed 10.476117871
From worker 30: computed 10.482844322
From worker 7: computed 10.494269832
From worker 26: computed 10.51448675
From worker 2: computed 10.638734037
From worker 9: computed 10.640405083
From worker 21: computed 10.517608783
From worker 18: computed 10.517825926
From worker 32: computed 10.527195295
From worker 10: computed 10.507530409
From worker 33: computed 10.636384727
From worker 5: computed 10.623993734
From worker 31: computed 10.518583279
From worker 28: computed 10.619155467
From worker 12: computed 10.621142638
From worker 4: computed 10.598856911
From worker 17: computed 10.709858187
From worker 22: computed 10.473410367
From worker 11: computed 8.818148052
From worker 20: computed 8.874088159
From worker 3: computed 8.805013057
From worker 9: computed 8.850652172
From worker 23: computed 8.785906518
From worker 29: computed 8.79020795
From worker 6: computed 8.789653886
From worker 13: computed 8.809982671
28.982518 seconds (1.76 M allocations: 90.266 MiB, 0.04% gc time, 124 lock conflicts, 6.43% compilation time: 4% of which was recompilation)