LoginSignup
20
23

More than 3 years have passed since last update.

Juliaでのスレッド並列を調べてみた

Last updated at Posted at 2020-10-02

Juliaでのプロセス並列についての情報はそれなりにありますが、スレッド並列についての情報があまりありません。
公式のサイトには@threadsを使ったForループのスレッド並列化について書かれてあります。
しかし、スレッド並列をもう少し詳細にやるにはどうするのでしょうか?
まだ完全に動作についてわかっていませんが、
https://www.oxinabox.net/2017/11/20/Thread-Parallelism-in-Julia.html
に有益な情報がありましたので、現在のバージョン(1.5)で試してみました。

環境

  • Julia 1.5.0
  • Mac OS 10.14.6

コード例

わかり次第色々追記したいと思いますが、とりあえずスレッドごとに異なる動作をさせることに成功しました。
まず、4コアであれば

export JULIA_NUM_THREADS=4

とすると4スレッドで走る設定になります。そして、以下のコード:

using Base.Threads
@show nthreads()

Threads.threading_run() do 
    i = threadid()
    if i == 1
        println("Yes! I am 1: id is $i")
    elseif i == 2
        println("Yes! He is 2 id is $i")
    elseif i == 3
        println("Yes! She is 3 id is $i")
    elseif i == 4
        println("Yes! We are 4 id is $i")
    end
    #called[threadid()] = true
end

を実行すると、

nthreads() = 4
Yes! I am 1: id is 1
Yes! We are 4 id is 4
Yes! She is 3 id is 3
Yes! He is 2 id is 2

となります。

ここでポイントはthreading_run()です。
ソース
https://github.com/JuliaLang/julia/blob/e49e83254db641ac91aea328dbccbc4282409101/base/threadingconstructs.jl
を見ると、
リンク先にあったccallで呼んでいたjl_threading_runはJuliaのこの関数に置き換えられたようです。

もう少し複雑なことができるか、試してみたいと思います。

追記1

次に、threading_run()の内部で関数を呼んだらどうなるかを調べてみました。以下のコードを実行してみました。

using Base.Threads
@show nthreads()

function test()
    id = threadid()
    n = nthreads()
    nmax = 100
    nbun = div(nmax,n)
    ista = (id-1)*nbun + 1
    iend = ista + nbun -1
    println("$ista $iend at $id")
    a = 0
    for i=ista:iend
        a += i
    end
    return a
end

data = zeros(Int64,nthreads())
Threads.threading_run() do 
    i = threadid()
    data[i] = test()
end
println(data)
println(sum(data))

実行結果は、

nthreads() = 4
76 100 at 4
26 50 at 2
51 75 at 3
1 25 at 1
[325, 950, 1575, 2200]
5050

となり、関数testの中でのスレッド番号はちゃんとスレッドごとに異なる値になっているようです。
これによって、for文を手動で並列化できました。

追記2

@antimon2 さんのコメントがとても勉強になります。
@spawnについても調べていきたいと思います。

また、Julia 1.5.xからは

julia -t 4 test.jl

のようにすると4スレッドで走るようです。これは便利ですね。

追記3

次に、より実用的な場合を考えて、引数に配列を入れてみましょう。

function test!(data)
    id = threadid()
    n = nthreads()
    nmax = 100
    nbun = div(nmax,n)
    ista = (id-1)*nbun + 1
    iend = ista + nbun -1
    println("$ista $iend at $id")
    a = 0
    for i=ista:iend
        a += i
    end

    data[id] = a
end

これを実行すると

nthreads() = 4
76 100 at 4
26 50 at 2
51 75 at 3
1 25 at 1
[325, 950, 1575, 2200]
5050

となり、ちゃんと配列内の要素が更新されていました。

ループがスレッド数で割り切れるかどうかを意識しないのであれば、

module Parallel
    using Base.Threads
    function get_looprange(N)
        id = threadid()
        n = nthreads()
        dn = N % n
        nbun = div(N - dn,n)
        ista = (id-1)*nbun + 1 
        ista += ifelse(id <= dn,(id-1),dn)
        nbun += ifelse(id <= dn,1,0)
        iend = ista + nbun - 1
        return ista:iend
    end
end

というモジュールを作っておけば、

function test!(data)
    nmax = 100
    id = threadid()
    ran = Parallel.get_looprange(nmax)
    println("$ran at $id")
    a = 0
    for i=ran
        a += i
    end
    data[id] = a
    return 
end

となり、実行すると、100が割り切れない3並列でも

nthreads() = 3
1:34 at 1
35:67 at 2
68:100 at 3
[595, 1683, 2772]
5050

と正しい結果が得られました。

20
23
4

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
20
23