Juliaでのプロセス並列についての情報はそれなりにありますが、スレッド並列についての情報があまりありません。
公式のサイトには@threads
を使ったForループのスレッド並列化について書かれてあります。
しかし、スレッド並列をもう少し詳細にやるにはどうするのでしょうか?
まだ完全に動作についてわかっていませんが、
https://www.oxinabox.net/2017/11/20/Thread-Parallelism-in-Julia.html
に有益な情報がありましたので、現在のバージョン(1.5)で試してみました。
環境
- Julia 1.5.0
- Mac OS 10.14.6
コード例
わかり次第色々追記したいと思いますが、とりあえずスレッドごとに異なる動作をさせることに成功しました。
まず、4コアであれば
export JULIA_NUM_THREADS=4
とすると4スレッドで走る設定になります。そして、以下のコード:
using Base.Threads
@show nthreads()
Threads.threading_run() do
i = threadid()
if i == 1
println("Yes! I am 1: id is $i")
elseif i == 2
println("Yes! He is 2 id is $i")
elseif i == 3
println("Yes! She is 3 id is $i")
elseif i == 4
println("Yes! We are 4 id is $i")
end
#called[threadid()] = true
end
を実行すると、
nthreads() = 4
Yes! I am 1: id is 1
Yes! We are 4 id is 4
Yes! She is 3 id is 3
Yes! He is 2 id is 2
となります。
ここでポイントはthreading_run()
です。
ソース
https://github.com/JuliaLang/julia/blob/e49e83254db641ac91aea328dbccbc4282409101/base/threadingconstructs.jl
を見ると、
リンク先にあったccallで呼んでいたjl_threading_run
はJuliaのこの関数に置き換えられたようです。
もう少し複雑なことができるか、試してみたいと思います。
追記1
次に、threading_run()
の内部で関数を呼んだらどうなるかを調べてみました。以下のコードを実行してみました。
using Base.Threads
@show nthreads()
function test()
id = threadid()
n = nthreads()
nmax = 100
nbun = div(nmax,n)
ista = (id-1)*nbun + 1
iend = ista + nbun -1
println("$ista $iend at $id")
a = 0
for i=ista:iend
a += i
end
return a
end
data = zeros(Int64,nthreads())
Threads.threading_run() do
i = threadid()
data[i] = test()
end
println(data)
println(sum(data))
実行結果は、
nthreads() = 4
76 100 at 4
26 50 at 2
51 75 at 3
1 25 at 1
[325, 950, 1575, 2200]
5050
となり、関数testの中でのスレッド番号はちゃんとスレッドごとに異なる値になっているようです。
これによって、for文を手動で並列化できました。
追記2
@antimon2 さんのコメントがとても勉強になります。
@spawn
についても調べていきたいと思います。
また、Julia 1.5.xからは
julia -t 4 test.jl
のようにすると4スレッドで走るようです。これは便利ですね。
追記3
次に、より実用的な場合を考えて、引数に配列を入れてみましょう。
function test!(data)
id = threadid()
n = nthreads()
nmax = 100
nbun = div(nmax,n)
ista = (id-1)*nbun + 1
iend = ista + nbun -1
println("$ista $iend at $id")
a = 0
for i=ista:iend
a += i
end
data[id] = a
end
これを実行すると
nthreads() = 4
76 100 at 4
26 50 at 2
51 75 at 3
1 25 at 1
[325, 950, 1575, 2200]
5050
となり、ちゃんと配列内の要素が更新されていました。
ループがスレッド数で割り切れるかどうかを意識しないのであれば、
module Parallel
using Base.Threads
function get_looprange(N)
id = threadid()
n = nthreads()
dn = N % n
nbun = div(N - dn,n)
ista = (id-1)*nbun + 1
ista += ifelse(id <= dn,(id-1),dn)
nbun += ifelse(id <= dn,1,0)
iend = ista + nbun - 1
return ista:iend
end
end
というモジュールを作っておけば、
function test!(data)
nmax = 100
id = threadid()
ran = Parallel.get_looprange(nmax)
println("$ran at $id")
a = 0
for i=ran
a += i
end
data[id] = a
return
end
となり、実行すると、100が割り切れない3並列でも
nthreads() = 3
1:34 at 1
35:67 at 2
68:100 at 3
[595, 1683, 2772]
5050
と正しい結果が得られました。