Juliaの並列化手法について
今まで@threads
を用いた簡素な並列化ばかり行ってきてしまったのですが、ここにきて、使用メモリに応じてスレッド数を自動で変更できる機構を作成しようと思い、@spawn
を用いてこれを実現しようと試みました。
通常のfor文
次のような、正弦波をプロットするコードを作成してみました。
using Plots
total_items = 101
θ_list = range(-π、π,total_items)
f(i) = sin(θ_list[i])
y = zeros(total_items)
for i in 1:total_items
y[i] = f(i)
end
plot(θ_list, y)
これを回すとy
に値が代入されていき、正弦波が出力されます。
@threads
を用いた並列化
for文の頭に@threads
を追記するだけで、Juliaを立ち上げる際に指定した並列数N0
julia -t N0
で並列化を行ってくれます。
using Plots
using Base.Threads
total_items = 101
θ_list = range(-π,π,total_items)
f(i) = sin(θ_list[i])
y = zeros(total_items)
@threads for i in 1:total_items
y[i] = f(i)
end
plot(θ_list, y)
もちろんJULIA_NUM_THREADS
を用いても構いません。
@spawn
を用いた並列化
まず、サンプルコードを記述します。
N_trunc
はN0 >= N_trunc
とし、Julia実行時に指定した並列数よりも小さな並列数で計算を回してくれます。
using Plots
using Base.Threads
total_items = 101
θ_list = range(-π,π,total_items)
f(i) = sin(θ_list[i])
y = zeros(total_items)
N_trunc = 4
chunk_size, remainder = divrem(total_items, N_trunc)
sizes = [chunk_size + (i <= remainder) for i in 1:N_trunc]
start_indices = [1; cumsum(sizes[1:end-1]) .+ 1]
tasks = []
for id in 1:N_trunc
task = @spawn begin
start_id = start_indices[id]
stop_id = start_id + sizes[id] - 1
for i in start_id:stop_id
y[i] = f(i)
end
end
push!(tasks,task)
end
fetch.(tasks)
plot(θ_list, y)
説明
ここでは、1スレッドごとに担当する分配された配列のことをチャンクと呼ぶことにします。
-
divrem()
関数を用いて、for文の回数(ここではtotal_items
回)を大雑把にN_trunc
分割しchunk_size
とします。また余りの値もremainder
に格納します。julia> divrem(total_items, N0) (25, 1)
- 続いて、
N_trunc
に分割した際のチャンクの長さをリストsizes
として生成します。このとき、julia> sizes = [chunk_size + (i <= remainder) for i in 1:N_trunc] 4-element Vector{Int64}: 26 25 25 25
(i <= remainder)
の部分により、割り切れなかった余りの分を、最初の数チャンクに対して等分配するように自動化されます。 - それぞれのチャンクの最初の
index
を保持するリストstart_indices
を生成します。julia> start_indices = [1; cumsum(sizes[1:end-1]) .+ 1] 4-element Vector{Int64}: 1 27 52 77
ここまでで、タスクを分配する準備ができましたので、for文を用いてタスクを生成します。
tasks = []
for id in 1:N_trunc
task = @spawn begin
start_id = start_indices[id]
stop_id = start_id + sizes[id] - 1
for i in start_id:stop_id
y[i] = f(i)
end
end
push!(tasks,task)
end
最後にfetch.(tasks)
を実行すれば並列化された計算が始まります。
このような流れで、同等の並列化が、並列数を削減して行うことが可能となります。
(追記)Base.Semaphore による並列化
@antimon2 さまより、Base.Semaphore
を用いた手法をご提案いただきました。
詳しくかみ砕けていないため急ぎ早にはなりますが、例えば次のようなコードですと、同等の方法で並列化数に制限がかけられそうです。
using Plots
using Base.Threads
total_items = 101
θ_list = range(-π,π,total_items)
f(i) = sin(θ_list[i})
y = zeros(total_items)
sem = Base.Semaphore(N_trunc)
for i in 1:total_items
@spawn begin
Base.acquire(sem) do
y[i] = f(i)
end
end
end
plot(θ_list, y)
最後に
またもう少し賢い手法を見つけたら、引き続き更新していきます。