1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Juliaとcomputational graph

Posted at

Computational graph (計算グラフ)とは

数値計算では、互いに依存関係がある複数のデータを逐次的、もしくは、並列的に計算することがよくあります。そのような場合に、計算グラフの概念を使うと、それらの間の依存関係を把握したり、並列して計算できる部分を抽出することが容易になります。

例えば、C = A + Bという計算を考えてみましょう。A, B, Cは行列でも良いですし、単なるスカラーでも良いです。
この場合、Cの評価結果は、A, Bに依存しています。つまり、A, Bが確定しないと、Cは評価できません。

もう少し複雑な例を考えてみましょう。
以下の計算式を考えます:

  1. a = A + B
  2. b = C + D
  3. c = a + b

この場合、まずは、abを評価した後に、cの評価が可能になります。
ただし、a, bの演算は原理的に並列して計算できるはずです。

このような依存関係を表すために用いられるのが 有向非巡回グラフ(Directed Acyclic Graph, DAG) です。

    A       B       C       D
     │       │       │       │
     ▼       ▼       ▼       ▼
   +---+   +---+   +---+   +---+
   | A |   | B |   | C |   | D |
   +---+   +---+   +---+   +---+
     │       │       │       │
     └───┬───┘       └───┬───┘
         ▼               ▼
       +---+           +---+
       | a |           | b |
       +---+           +---+
         │               │
         └──────┬────────┘
                ▼
              +---+
              | c |
              +---+

有向(Directed)とは、各辺には方向性があることを意味します。たとえば、A -> B は「AがBに影響を与える」または「BはAに依存する」という意味を表します。非巡回(Acyclic)とは、グラフにサイクル(閉路)が存在しないことを意味します。
上のグラフをみれば、A + BC + Dが並列して実行できることが一目で分かります。

このように、データ間の依存関係をDAGで表した構造を計算グラフと呼びます。一旦、計算グラフに表してしまえば、演算の並列構造を依存関係を機械的に抽出し、並列して計算できる部分を複数のプロセッサに割りあてるようなことも可能になります。

もちろん、Graphs.jlなどのような汎用的なグラフライブラリをつかって、計算グラフを実装することも出来ますが、ここでは、並列計算にも対応したDagger.jlを使ってみましょう。

上の構造をDagger.jlを使って実装すると、以下の様なコードになります。

using Dagger

# 入力ノードを定義
A = Dagger.@spawn () -> 1 # 定数
B = Dagger.@spawn () -> 2
C = Dagger.@spawn () -> 3
D = Dagger.@spawn () -> 3

# 中間ノードの計算
a = Dagger.@spawn A + B
b = Dagger.@spawn C + D

# 最終ノードの計算
c = Dagger.@spawn a + b

@show fetch(c)

ここで、@spawnは、後ろに続く式を使ってDaggerのタスク(計算ノード)を定義します。@spawnによって作られる各計算タスクは依存関係を持つことができ、それらの依存に基づいて計算グラフ(DAG)が自動的に構築されます。
ここで注意すべきは、@spawn自体は依存関係を生成するだけで、実際の計算は最後にfetch(c)が呼ばれたときに行われます。
つまり、cの値を評価するのに必要な全てのタスクが、依存関係に従って順次評価されます。

  1. 依存関係の解析
  • cを評価するために、まずabが評価される必要があります。
  • aを評価するためにはABbを評価するためにはCDがそれぞれ必要です。
  1. 計算の実行
  • fetch(c)が呼ばれると、cが依存している全てのタスクが再帰的に評価されます。
  • 評価順序は、Daggerが内部的に構築したDAGに基づいて決定されます。
  1. 並列実行の最適化 (後で説明します)
  • 依存関係がないタスク(たとえば、AB、またはCD)は並列に実行可能です。

上のコードを実行すると、

fetch(c) = 9

が得られるでしょう。でも、これだけだと、計算グラフを使うメリットがわかりにくいですよね。逐次的に、各ステップを即時的に計算・評価するのと結果は変わりません。

並列計算 Dagger.jl

次に、並列計算を試してみましょう。
今回試すのは、Distributedベースのプロセス並列です。
大きさN\times Nの2つの行列積計算をnterm回繰り返して、すべての結果の和を計算するという簡単なセットアップです。並列効率を出すために、N=6000に設定しています (1回の行列積計算は10秒ほど)。

色々試行錯誤した結果、以下のコードが正しく動作するようです。

using Distributed
using Dagger
using LinearAlgebra

BLAS.set_num_threads(1)
@show nworkers()

N = 6000

@everywhere function func(x, y)
    flush(stdout)
    t1 = time_ns()
    res = sum(x * y)
    t2 = time_ns()
    println("computed $((t2-t1)/1e+9)")
    return res
end

@everywhere func2(x...) = sum(x)

nterm = 32

matrices = []

for n in 1:nterm
    m1 = Dagger.@spawn ()->rand(N, N)
    m2 = Dagger.@spawn ()->rand(N, N)
    push!(matrices, (m1, m2))
end

products = []
for n in 1:nterm
   p = Dagger.@spawn func(matrices[n]...)
   push!(products, p)
end

@show "A"

res = Dagger.@spawn func2(products...)

@show "B"
@time sum(fetch(res))
@show "C"

ポイントとしては、Dagger.@spawnには必ず関数呼び出しの形、@spawn f(x, y)、の形でタスクを定義することが重要です。ここで、x, yは固定された値でも良いですし、別の@spawnの評価結果でも大丈夫です。一方、@spawn x + yな形だと上手く動作しません。

私の手元のワークステーション (64コア)で実行してみたところ、

16ワーカー (julia -p 16 -t 1 test.jl)

nworkers() = 16
"A" = "A"
"B" = "B"
      From worker 16:	computed 8.993550151
      From worker 6:	computed 9.047431069
      From worker 4:	computed 9.041195385
      From worker 8:	computed 9.039754094
      From worker 2:	computed 9.029207313
      From worker 15:	computed 9.005605376
      From worker 12:	computed 9.0252073
      From worker 5:	computed 9.028951724
      From worker 10:	computed 9.073104744
      From worker 3:	computed 9.017677308
      From worker 14:	computed 9.020359432
      From worker 12:	computed 8.880520181
      From worker 3:	computed 8.921722932
      From worker 16:	computed 8.99813879
      From worker 4:	computed 8.904876451
      From worker 9:	computed 8.900564684
      From worker 6:	computed 8.985334225
      From worker 8:	computed 8.911503034
      From worker 2:	computed 9.01582081
      From worker 7:	computed 8.928685454
      From worker 13:	computed 8.900212701
      From worker 16:	computed 8.839050387
      From worker 14:	computed 8.781535921
      From worker 8:	computed 8.821355677
      From worker 15:	computed 8.815256977
      From worker 17:	computed 8.775349723
      From worker 7:	computed 8.808136017
      From worker 13:	computed 8.792752806
      From worker 6:	computed 8.84895146
      From worker 4:	computed 8.784235002
      From worker 7:	computed 8.787864021
      From worker 13:	computed 8.747277693
 43.638821 seconds (1.75 M allocations: 88.946 MiB, 0.03% gc time, 84 lock conflicts, 4.00% compilation time: 4% of which was recompilation)
"C" = "C"

32ワーカー (julia -p 32 -t 1 test.jl)

nworkers() = 32
"A" = "A"
"B" = "B"
      From worker 23:	computed 10.55963046
      From worker 11:	computed 10.587748061
      From worker 3:	computed 10.603463832
      From worker 20:	computed 10.652398373
      From worker 13:	computed 10.628970288
      From worker 24:	computed 10.611196972
      From worker 16:	computed 10.476117871
      From worker 30:	computed 10.482844322
      From worker 7:	computed 10.494269832
      From worker 26:	computed 10.51448675
      From worker 2:	computed 10.638734037
      From worker 9:	computed 10.640405083
      From worker 21:	computed 10.517608783
      From worker 18:	computed 10.517825926
      From worker 32:	computed 10.527195295
      From worker 10:	computed 10.507530409
      From worker 33:	computed 10.636384727
      From worker 5:	computed 10.623993734
      From worker 31:	computed 10.518583279
      From worker 28:	computed 10.619155467
      From worker 12:	computed 10.621142638
      From worker 4:	computed 10.598856911
      From worker 17:	computed 10.709858187
      From worker 22:	computed 10.473410367
      From worker 11:	computed 8.818148052
      From worker 20:	computed 8.874088159
      From worker 3:	computed 8.805013057
      From worker 9:	computed 8.850652172
      From worker 23:	computed 8.785906518
      From worker 29:	computed 8.79020795
      From worker 6:	computed 8.789653886
      From worker 13:	computed 8.809982671
 28.982518 seconds (1.76 M allocations: 90.266 MiB, 0.04% gc time, 124 lock conflicts, 6.43% compilation time: 4% of which was recompilation)
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?