Edited at

Juliaで数値計算 その2:コードサンプル〜わかりやすい書き方編〜

More than 1 year has passed since last update.

Juliaで数値計算 その1:コードサンプル〜基本的計算編〜

https://qiita.com/cometscome_phys/items/31d0b811345a3e12fcef

に引き続いて、FortranやCからやってきた人のためにJuliaの書き方を紹介する。

この記事はその2である。

FortranやCで数値計算をしている場合、コードを拡張していくたびに引数が増えていったりして困ったことが一度はあると思う。そのようなとき、FORTRANであればcommon文、Fortranであればmoduleなどを使って、引数を減らしたりする。しかし、common文にせよmodule文にせよ、変数や定数をグローバルに保持しているので、コードのある場所において、何を参照しているのかがわかりにくくなってしまう。Juliaでの一つのやり方をここに紹介する。


moduleとstruct


module

moduleという機能がある。これは、Fortranにも似たものがある。


module.jl

module Testmodule

a = 3.0
function testfunc(x)
b = x + 1
println("b = $b")
return b
end
end

このようにmoduleを定義しておくと、パッケージを使うときのように、usingを使ってアクセスすることができる。たとえば、


module.jl

using .Testmodule

println(Testmodule.a)
x = 2.0
b1 = Testmodule.testfunc(x)
println("b1 = $b1")

とすると、Testmodule内の変数aにアクセスできる。また、関数testfuncも呼ぶことができる。using .Testmoduleのドットは、moduleのコード内の位置を表していて、.一個だと同じインテンド内にあるmoduleを呼ぶことになる。

毎回Testmodule.をつけるのが面倒であれば、


module.jl

module Testmodule2

export testfunc
a = 3.0
function testfunc(x)
b2 = x + 1
println("b2 = $b2")
return b2
end

function testfunc2(x)
b22 = x + 1
println("b22 = $b22")
return b22
end
end


とexportをつけておくと、exportに記述した関数や変数は、Testmodule2.としなくても呼ぶことができる:


module.jl

using .Testmodule2

println(Testmodule2.a)
x = 2.0
b1 = testfunc(x)
println("b1 = $b1")

ただし、のちのちの読みやすさのためには、Testmodule2.とドットをつけたほうがよい気がする。

コード中での機能ごとに異なるmoduleとして書いておけば、コードが何をしているかわかりやすくなるし、そのmoduleを別のコードへと移植することも容易となる。


struct

Juliaには、Fortranで言う所の構造体と似たものがある。Juliaでは、変更不可能なものをstruct、変更が可能なものをmutable structと呼んでいる。


変更不可能なもの:struct

変更不可能なものはstructを使う。ここではっきりと型宣言しておくと、速く実行できる。

たとえば、


struct.jl

structParameter

a::Float64
b::Float64
end

param1 = Parameter(2.3,5.2)
println(param1)
param2 = Parameter(2.1,4.3)
println(param2)


などと定義する。いま、aとbは倍精度実数であると定義した。Fortranで言うところの構造体と思えばよい。ただし、このstructは変更できない。変更したい場合には、mutable structを使う。


変更可能なもの:mutable struct

変更が可能なものはmutable structで定義する。たとえば、


struct.jl

mutable structmutable Spin

φ::Float64
z::Float64
end

とする。これの面白いところは、structに対しても多重ディスパッチを使うことができて、Spinというstructに対する+演算子などを定義できるところにある。多重ディスパッチなので、まず+が定義されているBaseをimportして、そこに型がSpinであったときの+演算子の挙動を書く。それにより、Spinという型同士の足し算を定義できる。つまり、


struct.jl

import Base:+

+(a::Spin,b::Spin) = Spin(a.φ-b.φ,a.z+b.z)
spin1 = Spin(1.2,0.2)
spin2 = Spin(0.2,0.8)
spin3 = spin1 + spin2
println(spin3)

のようにすれば、φ同士は引いて、z同士は足す、という演算を定義することができる。

また、引数の型をSpinとしたときにのみ動くfunctionを定義して、


struct.jl

function calc_S(spin1::Spin)

S = zeros(Float64,3)
S[1] = cos(spin1.φ)*sqrt(1-spin1.z^2)
S[2] = sin(spin1.φ)*sqrt(1-spin1.z^2)
S[3] = spin1.z
return S
end

S1 = calc_S(spin1)
println(S1)


とすれば、古典スピンの型としてSpinを用意して、適宜Sx,Sy,Szを計算できるようになる。

このような書き方の利点は、φやzを直接いじらずに、Spinというものをいじっている形に書くことができるので、コードがわかりやすくデバッグしやすくなるという点である。

Spin型に関する演算のバグを取ることができれば、メインのコードではその部分にバグがないとして書くことができる。また、Spin型を拡張してみたい場合は、そこを拡張しておけば、メインのコードは一切変更しなくても済む。

なお、moduleとstructの名前は、最初の文字を大文字とするのが推奨されているそうだ。


わかりやすい書き方の例

以下のコードはJuliaを使った一つのわかりやすいと思われる書き方の例である。

コードであとあとパラメータを増やしたりしたいときにも対応するためには、structを引数としてもっておけばよい。structに新しいパラメータを足せば、それを引数としたすべてのfunctionでそのパラメータを利用可能にすることができる。つまり、メインコードをいじらなくてすむ。また、パラメータとして使うけれども、時々変えたいようなもの(たとえば温度とか)は、mutable structで作って引数としてもっておけばよい。つまり、structとmutable structの二つを引数としておけば、メインコードがかなりすっきりする。

以下にサンプルを載せる。

まず、moduleとして、使うパラメータや変数を載せたstructとmutable structを定義する。

パラメータを増やしたければここを増やせばよい。


sample.jl

module Test

structParam #変更しないパラメータ
a::Float64 #実数
b::Int64 #整数
c::ComplexF64 #複素数
vector_a::Array{Int64,1} #1次元配列。要素が整数
vector_b::Array{Float64,1} #1次元配列。要素が実数
vector_c::Array{ComplexF64,1} #1次元配列。要素が複素数
matrix_c::Array{ComplexF64,2} #2次元配列。要素が複素数
σ::Array{Array{ComplexF64,2},1} #1次元配列。要素が2次元配列
nT::Int64
Tmax ::Float64
Tmin ::Float64
end

mutable structmutable Variables #変更可能なパラメータ
T::Float64 #実数
k::Array{Float64,1} #一次元配列。要素が実数
end
end


上では使いそうな型をいろいろ載せておいた。

これを使うには、


sample.jl

using .Test #自前のモジュールを読み込むときは、.をつける。これはmoduleの位置を表す。

a = 2.0
b = 4
c = 3.0 + 5im
vector_a = [3,2,1]
vector_b = [3.9,2.3,4.3]
vector_c = rand(Float64,3)+im*rand(Float64,3) #randは-1から1の乱数を作る。
matrix_c = rand(ComplexF64,3,3)
#パウリ行列
σx = [0 1
1 0]
σy = [0.0 -im
im 0.0]
σz = [1.0 0.0
0.0 -1.0]
σ = [σx,σy,σz]
nT = 12
Tmax = 1.0
Tmin = 0.1

param1 = Test.Param(a,b,c,vector_a,vector_b,vector_c,matrix_c,σ,nT,Tmax,Tmin)

T = 0.1
k = [π,0.5π,0.25π]
var1 = Test.Variables(T,k)


とすればよい。サンプルとして、パウリ行列を要素とした一次元配列σを定義しておいた。

そして、これらを引数とした関数を


sample.jl

function tempdep(param,var)

dT = (param.Tmax-param.Tmin)/(param.nT-1)
for iT = 1:param.nT
var.T =(iT-1)*dT + param.Tmin
g = calc_g(param,var)
println("T = $(var.T), g = $g")
end
end

function calc_g(param,var)
g = var.T*param.a
return g
end


とすれば、


sample.jl

x = 3

tempdep(param1,var1)

で実行できる。関数の中で参考にしているのはつねにparamなので、デバッグをするときに値を参照すべきなのはparamとvarであることがわかる。