8
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Juliaについてもっと早く知りたかったこと

Last updated at Posted at 2023-05-30

はじめに

Juliaのプログラムは何も考えなくても高速に動くことが魅力ですが、それでもJuliaの力を最大限活用できるように注意する必要があります。しかし、それらの注意点は、自分で調べながらプログラムを書いているだけではなかなか気づきにくいものです。本稿では実際にJuliaを使ってみて、なかなか自分で気づきにくい細かな計算テクニックについて紹介します。

次に示すコードは、やりたいことは実現できているものの、Juliaまたは数値計算の仕組みについてよく知って知っていればより高速化できるプログラムの例です。本稿は、これらのコードの問題点が思いつかなかった人向けの記事になります。

#異なる正規分布から一つずつサンプルを抽出したい
using Distributions

N = 10^6
a = zeros(N)

for i in 1:N
    a[i] = rand(Normal(i,2.))
end
#ベクトルの各要素に繰り返し操作をしたい
N = 10^5

function f(x)
    return @. x^2 + 0.01
end


function g(x)
    for i in  1:N
        x = f(x)
    end
    return x
end

u = rand(N)
g(u)
#数列の和の計算

N = 10^7
a = rand(N)
f(x) =  x^2 + cos(x) 

@time sum([f(x) for x in 1:N])

for文はできるだけ避ける

乱数生成

for文を避けた方がいいのは当たり前ですが、意識しないと意外と気付きにくい場面もあります。 
基本的な例として、各要素が標準正規分布に従うベクトルを生成することを考えます。

using Distributions, BenchmarkTools

N = 10^6
a = zeros(N)

@benchmark  for i in 1:N
    a[i] = rand(Normal())
end

このように記述するのはfor文を使っているため非効率的であり、実際に

BenchmarkTools.Trial: 52 samples with 1 evaluation.
 Range (min … max):  94.737 ms … 111.934 ms  ┊ GC (min … max): 2.30% … 2.50%
 Time  (median):     96.967 ms               ┊ GC (median):    4.48%
 Time  (mean ± σ):   96.764 ms ±   2.342 ms  ┊ GC (mean ± σ):  3.94% ± 0.99%

                                                  █▆▆▁          
  ▇▆▆▆▆▁▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▄▆████▇▄▆▁▄▁▁▄ ▁
  94.7 ms         Histogram: frequency by time         97.4 ms <

 Memory estimate: 76.28 MiB, allocs estimate: 3998979.

とおよそ100msの計算時間がかかります。一方これをfor文を使わずに一気に乱数を生成させて

using Distributions, BenchmarkTools

N = 10^6
@benchmark  a = rand(Normal(),N)

と記述するだけで、

BenchmarkTools.Trial: 1223 samples with 1 evaluation.
 Range (min … max):  3.726 ms …   5.528 ms  ┊ GC (min … max): 0.00% … 14.31%
 Time  (median):     3.826 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.079 ms ± 525.970 μs  ┊ GC (mean ± σ):  3.16% ±  6.26%

   ▄█▇                                                         
  ▄███▇▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▅▄▂▂▁▃▄▄▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▂▆ ▃
  3.73 ms         Histogram: frequency by time        5.51 ms <

 Memory estimate: 7.63 MiB, allocs estimate: 2.

と計算時間は約20倍高速になりました。

では次の例はどうでしょうか。

#異なる正規分布から一つずつサンプルを抽出したい
using Distributions

N = 10^6
a = zeros(N)

for i in 1:N
    a[i] = rand(Normal(i,2.))
end
BenchmarkTools.Trial: 24 samples with 1 evaluation.
 Range (min … max):  214.108 ms … 232.341 ms  ┊ GC (min … max): 2.05% … 2.24%
 Time  (median):     215.531 ms               ┊ GC (median):    2.10%
 Time  (mean ± σ):   216.145 ms ±   3.653 ms  ┊ GC (mean ± σ):  2.49% ± 0.50%

  █      ▂                                                       
  █▃▃▁▁▁▁██▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▁
  214 ms           Histogram: frequency by time          232 ms <

 Memory estimate: 106.80 MiB, allocs estimate: 4998979.

今度は各要素が異なる正規分布に従うため、同じ手法を使うことは難しそうに思えます。しかし、正規分布の性質を用いて、正規乱数は標準正規分布から生成できることに気づけば次のように改善できます。

N = 10^6

mu = 1:N
@benchmark a = rand(Normal(0.,2.), N) .+ mu
BenchmarkTools.Trial: 9723 samples with 1 evaluation.
 Range (min … max):  389.641 μs …   3.815 ms  ┊ GC (min … max): 0.00% … 85.92%
 Time  (median):     480.662 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   509.577 μs ± 169.546 μs  ┊ GC (mean ± σ):  5.77% ± 11.17%

  ▁▃  ▆█▆                                                     ▃ ▁
  ██▄▇████▄▁▃▁▁▁▃▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▆▇█ █
  390 μs        Histogram: log(frequency) by time       1.39 ms <

 Memory estimate: 1.53 MiB, allocs estimate: 6.

結果として、計算は約500倍高速化されました。

では、for文をさけて一気にサンプリングをかけるとどのように計算時間は短縮されるのでしょうか。Nこの乱数を生成するプログラムで実験してみた結果がこちらです。

using Distributions, BenchmarkTools, Statistics, Plots

N_list=[10^2, 10^3, 10^4, 10^5, 10^6, 10^7]


mean1 = zeros(length(N_list))

for j in 1:length(N_list)
    a = zeros(N)
    time = @benchmark for i in 1:N
        a[i] = rand(Normal(1,2))
    end
    mean1[j] = mean(time.times)
end

mean2 = zeros(length(N_list))

for j in 1:length(N_list)
    a = zeros(N)
    time = @benchmark begin
        a = rand(Normal(1,2), N)
    end
    mean2[j] = mean(time.times)
end


plot(N_list, mean1, label="for", marker=:circle,  xscale=:log10 , xlabel="N", ylabel="time")
plot!(N_list, mean2,label="at once", marker=:circle, xscale=:log10)

スクリーンショット 2023-05-26 17.03.52.png

このように、サンプル数Nにかかわらず、大きな定数倍だけプログラムが高速化されていることがわかります。計算量がNに依存しないのは、多分内部ですごく上手く処理をしているのでしょう。

線形代数の計算

線形代数の計算でも、for文を避け、できるだけdot演算に持ち込むことが大切です。
例えば、適当なサイズNのベクトル演算を考えます。内積などを全部for文で実装しようとしたプログラムと、dot演算を利用したプログラムの計算時間を比較してみると

using Distributions, BenchmarkTools, Statistics

N_list=[10^2, 10^3, 10^4, 10^5, 10^6, 10^7]


mean1 = zeros(length(N_list))

for j in 1:length(N_list)
    a = zeros(N)
    b = rand(N)
    c = rand(N)
    time = @benchmark for i in 1:N
        a[i] = (b[i]*c[i])^2
    end
    mean1[j] = mean(time.times)
end

mean2 = zeros(length(N_list))
for j in 1:length(N_list)
    a = zeros(N)
    b = rand(N)
    c = rand(N)
    time = @benchmark begin
        a = @. (b*c)^2
    end
    mean2[j] = mean(time.times)

end

using Plots


plot(N_list, mean1, label="for", marker=:circle, xscale=:log10,  xlabel="N", ylabel="time")
plot!(N_list, mean2,  label="at once", marker=:circle, xscale=:log10)

スクリーンショット 2023-05-26 17.10.46.png

このように差は一目瞭然です。計算量がNに依存しないのは、多分内部ですごく上手く処理をしているのでしょう。

余計なメモリを確保しない(.=演算を使おう)

次に注意すべきは、適切にドットイコールを利用することです。早速ですが、次のコードを見てみましょう。

using Distributions, BenchmarkTools

N = 10^5
M = 10^5



function create_new_vec1(vec)
    vec_update = (vec .+ vec) ./ 2 .+ 1
    return vec_update
end

function create_new_vec2(vec)
    vec = (vec .+ vec) ./ 2 .+ 1
    return vec
end

function eq_dot(vec)
    vec .= (vec .+ vec) ./ 2 .+ 1
    return vec
end

vec = rand(M)

@time for i in 1:N
    vec = create_new_vec1(vec)
end


@time for i in 1:N
    vec = create_new_vec2(vec)
end


@time for i in 1:N
    vec = eq_dot(vec)
end

それぞれやりたいことは同じベクトルの繰り返し計算ですが、functionの内部の書き方が少し違います。その違いだけで、それぞれの計算時間は

 12.174541 seconds (715.62 k allocations: 74.527 GiB, 12.58% gc time, 0.46% compilation time)
 12.235685 seconds (715.64 k allocations: 74.527 GiB, 12.30% gc time, 0.38% compilation time)
  1.508056 seconds (501.45 k allocations: 16.032 MiB, 1.92% compilation time)

とeq_dot(vec)を使うやり方が10倍以上高速になっています。

実は、前二つの関数は、out-of-placeな計算(計算ごとに別のメモリを割り当てる計算)を行なっているのに対して、三つ目の関数ではin-placeな計算(すでに割り当てられている領域に値を更新していく)を行なっていることになっており、メモリ使用量が抑えられているため、高速になっているのです。これが「.=」の効果です。数値計算ではin-placeな計算で十分なことが多いため、不用意なメモリ確保を行なっていないかどうか十分に注意する必要があります。

他の例として、最初に挙げた次のプログラム

#ベクトルの各要素に繰り返し操作をしたい
N = 10^5

function f(x)
    return @. x^2 + 0.01
end


function g(x)
    for i in  1:N
        x = f(x)
    end
    return x
end

u = rand(N)
g(u)

も配列の割り当てに注目すれば高速化を期待できます。繰り返し使用されている関数f(x)の出力に別の領域を割り当てているので、.=を用いて次のように改善すると:

#ベクトルの各要素に繰り返し操作をしたい
N = 10^5

function f(x)
    return @. x^2 + 0.01
end

function f_revised(x)
    return x .= @. x^2 + 0.01
end

function g(x)
    for i in  1:N
        x = f(x)
    end
    return x
end

function g_revised(x)
    for i in  1:N
        x = f_revised(x)
    end
    return x
end

u = rand(N)
@time g(u)


u = rand(N)
@time g_revised(u)
 13.894223 seconds (767.29 k allocations: 74.530 GiB, 12.65% gc time, 0.47% compilation time)
  1.284470 seconds (463.64 k allocations: 14.332 MiB, 0.86% gc time, 3.44% compilation time)

と確かに10倍程度速くなりました。

その他の細かいTips

sumの書き方

sum関数は、メモリを消費しないように記述する必要があります。f.(a)という配列を計算してしまうより、sum(f,a)と記述することが最適です。

N = 10^7
a = rand(N)
f(x) =  x^2 + cos(x) 

@time sum([f(x) for x in 1:N])
@time sum(f(x) for x in 1:N)
@time sum(f.(a))
@time sum(f,a)
  0.168028 seconds (45.07 k allocations: 78.586 MiB, 4.00% gc time, 7.02% compilation time)
  0.160805 seconds (90.53 k allocations: 4.782 MiB, 6.48% compilation time)
  0.068647 seconds (80.62 k allocations: 80.305 MiB, 11.87% compilation time: 100% of which was recompilation)
  0.064405 seconds (34.36 k allocations: 1.788 MiB, 15.31% compilation time: 100% of which was recompilation)

割り算の切り捨てに関する使い分け

一般に割り算の整数への切り捨てに関しては、trunc関数よりdiv関数を用いるのが良いとされています。

s=rand()
t=rand()
@time for i in 1:N
    @. trunc((s+1)/t)
end
@time for i in 1:N
    @. div((s+1),t)
end
1.301641 seconds (10.00 M allocations: 259.384 MiB, 1.17% gc time)
0.824829 seconds (8.00 M allocations: 198.349 MiB, 1.38% gc time)

しかし、配列に関して同様のことを行なった際、原因は不明ながら、その速度関係は逆転しました。

x=rand(N)
y=rand(N)
@time @. trunc(x/y)
@time @. div(x,y)
0.004294 seconds (6 allocations: 7.630 MiB)
0.024367 seconds (4 allocations: 7.630 MiB, 13.27% gc time)

このように公式の関数が複数ある際は小さなモデルで計算時間を計算して、事前にどれを使うのが適切か調べるのが良いと思われます。

参考文献

Julia公式Tips集には、さらにJuliaの数値計算について必要な情報が詰まっています。主な話題はこちらのページから仕入れました。

8
7
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?