Juliaで機械学習をする際、Flux.jlが使われることが多いと思いますが、これにはZygote.jlによる自動微分が使われています。Zygote.jlは自動微分が可能ですが、時々、自分が作った関数が自動微分できない場合が出てくるかと思います。
問題設定
using Zygote
function calc_Heff(s,J)
_,Nx = size(s)
num = length(J)-1
H = 0.0
Si = zeros(3)
Sj = zeros(3)
for ix=1:Nx
for k=1:3
Si[k] = s[k,ix]
end
for dx = -2:2
jx = ix + dx
jx += ifelse(jx > Nx,-Nx,0)
jx += ifelse(jx < 1,Nx,0)
d = dx^2
if d == 1
n = 1
elseif d == 4
n = 2
else
n = -1
end
if n < 0
continue
end
if n > num
continue
end
for k=1:3
Sj[k] = s[k,jx]
end
for k=1:3
H += -J[n]*Si[k]*Sj[k]
end
end
end
return H + J[end]
end
のような適当な関数
H(s,J) = \sum_{i=1}^{N_x} \sum_{l=1}^1 J_l \sum_k (s_{k,i} s_{k,i+l} + s_{k,i} s_{k,i-l}) + J_0
を考えます。ここで、$s$は$3 \times N_x$の配列です。そして、この関数を使った適当な関数
E(s,J) = H(s,J)^2
というものを考えます。
この関数のsに関する微分と
\frac{\partial E}{\partial s_{k,i}}
Jに関する微分
\frac{\partial E}{\partial J_l}
を自動微分したいとします。
自動微分するためにZygote.jlを使って、
function main()
Lx = 10
s = rand(3,Lx)
J = rand(3)
Heff = calc_Heff(s,J)
println(Heff)
f(J) = calc_Heff(s,J)
grad = gradient(f,J)[1]
d = 1e-4
for i=1:length(J)
println("i = $i")
Jp = deepcopy(J)
Jp[i] += d
dHdJ = (f(Jp)-f(J))/d
println("auto ",grad[i])
println("numerical: ",dHdJ)
end
fs(s) = calc_Heff(s,J)^2
grad = gradient(fs,s)[1]
for i=1:length(s)
println("i = $i")
sp = deepcopy(s)
sp[i] += d
dHds = (fs(sp)-fs(s))/d
println("numerical: ",dHds)
println("auto ",grad[i])
end
end
main()
を実行してみると、
ERROR: LoadError: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
というエラーが出てしまいます。
これは、配列の要素を変更するようなものはZygote.jlでは微分できない、ということを言っています。複雑な関数を考えたときに、配列の要素にアクセスすることはそれなりにあり得るので、このエラーはそれなりによく遭遇するかもしれません。
今回はこの問題をrruleを使って解決する方法について述べます。
rrule
rruleというものは、Zygote.jlでの自動微分で使われている関数です。詳しくは、Juliaでの自動微分について調べてみる 前半を参照してください。
しかし、rruleは少しわかりにくいかもしれません。結局何をすればいいのだろう、と困ってしまうかもしれません。そこで、簡単な処方箋について、ここで述べます。
ある関数$H(s,J)$があったとします。この時、rruleではpullbackというものを定義します。これは、2変数の場合には、
B_{kl}(\frac{\partial L}{\partial H}) \equiv \frac{\partial L}{\partial H} \frac{\partial H}{\partial s_{ki}} \\
B_{l}(\frac{\partial L}{\partial H}) \equiv \frac{\partial L}{\partial H} \frac{\partial H}{\partial J_l}
という二つの関数を計算するような関数を作れば良い、ということになります。ここで、$B_{kl}(\frac{\partial L}{\partial H})$の$\frac{\partial L}{\partial H}$は引数としてZygoteからもらってくるものだということに注意してください。もし$H$が配列であれば、この$\frac{\partial L}{\partial H}$も同じサイズの配列です。
これでもまだわかりにくいと思いますので、上の例の場合のrrulesを示します。
function ChainRulesCore.rrule(::typeof(calc_Heff),s,J)
y = calc_Heff(s,J)
_,Nx = size(s)
function pullback(ybar)
sbar = NoTangent()
num = length(J)-1
Si = zeros(3)
Sj = zeros(3)
dHdJ = zero(J)
dHdJ[end] = 1
for ix=1:Nx
for k=1:3
Si[k] = s[k,ix]
end
for dx = -2:2
jx = ix + dx
jx += ifelse(jx > Nx,-Nx,0)
jx += ifelse(jx < 1,Nx,0)
d = dx^2
if d == 1
n = 1
elseif d == 4
n = 2
else
n = -1
end
if n < 0
continue
end
if n > num
continue
end
for k=1:3
Sj[k] = s[k,jx]
end
for k=1:3
dHdJ[n] += -Si[k]*Sj[k]
end
end
end
dHds = zero(s)
for ix=1:Nx
for k=1:3
Si[k] = s[k,ix]
end
for dx = -2:2
jx = ix + dx
jx += ifelse(jx > Nx,-Nx,0)
jx += ifelse(jx < 1,Nx,0)
d = dx^2
if d == 1
n = 1
elseif d == 4
n = 2
else
n = -1
end
if n < 0
continue
end
if n > num
continue
end
for k=1:3
Sj[k] = s[k,jx]
end
for k=1:3
dHds[k,ix] += -J[n]*Sj[k]
dHds[k,jx] += -J[n]*Si[k]
end
end
end
return sbar,dHds*ybar,dHdJ*ybar
end
return y,pullback
end
ここで、引数ybar
は$\frac{\partial L}{\partial H}$のことです。今は、$\frac{\partial H}{\partial J_l}$は
\frac{\partial H}{\partial J_l} = \sum_{i=1}^{N_x} \sum_k (s_{k,i} s_{k,i+l} + s_{k,i} s_{k,i-l})
みたいな感じですね。そして、$\frac{\partial H}{\partial s_{ki}} $は
\frac{\partial H}{\partial s_{ki}} = \sum_{l=1}^1 J_l (s_{k,i+l} + s_{k,i-l}) + \sum_{l=1}^1 J_l(s_{k,i-l} + s_{k,i+l} )
のような感じです。
どんな複雑な関数でも、微分を連鎖律で計算する際にこのpullbackが必要になりますから、この部分さえ書いておけばちゃんと自動微分ができるようになります。
このように、rruleさえ書いておけば、上のようなエラーで怒られることはなくなります。
なお、このpullbackの利点ですが、$H$がスカラーではなく配列の場合には
B_{kl}(\frac{\partial L}{\partial H}) \equiv \sum_{nm} \frac{\partial L}{\partial H_{nm}} \frac{\partial H_{nm}}{\partial s_{ki}}
のように、配列の足の和を取る形で定義されている点です。このように定義されているおかげで、$H$が配列でもその微分がテンソルになる煩雑さから解放されています。もし$ \frac{\partial H_{nm}}{\partial s_{ki}} $をそのまま使う場合には、添字が四つの量を計算する必要があり、煩雑です。
全体のコード
全体のコードは
using Zygote
using ChainRulesCore
function calc_Heff(s,J)
_,Nx = size(s)
num = length(J)-1
H = 0.0
Si = zeros(3)
Sj = zeros(3)
for ix=1:Nx
for k=1:3
Si[k] = s[k,ix]
end
for dx = -2:2
jx = ix + dx
jx += ifelse(jx > Nx,-Nx,0)
jx += ifelse(jx < 1,Nx,0)
d = dx^2
if d == 1
n = 1
elseif d == 4
n = 2
else
n = -1
end
if n < 0
continue
end
if n > num
continue
end
for k=1:3
Sj[k] = s[k,jx]
end
for k=1:3
H += -J[n]*Si[k]*Sj[k]
end
end
end
return H + J[end]
end
function ChainRulesCore.rrule(::typeof(calc_Heff),s,J)
y = calc_Heff(s,J)
_,Nx = size(s)
function pullback(ybar)
sbar = NoTangent()
num = length(J)-1
Si = zeros(3)
Sj = zeros(3)
ydHdJ = @thunk (begin
dHdJ = zero(J)
dHdJ[end] = 1
for ix=1:Nx
for k=1:3
Si[k] = s[k,ix]
end
for dx = -2:2
jx = ix + dx
jx += ifelse(jx > Nx,-Nx,0)
jx += ifelse(jx < 1,Nx,0)
d = dx^2
if d == 1
n = 1
elseif d == 4
n = 2
else
n = -1
end
if n < 0
continue
end
if n > num
continue
end
for k=1:3
Sj[k] = s[k,jx]
end
for k=1:3
dHdJ[n] += -Si[k]*Sj[k]
end
end
end
dHdJ*ybar
end)
ydHds = @thunk (begin
dHds = zero(s)
for ix=1:Nx
for k=1:3
Si[k] = s[k,ix]
end
for dx = -2:2
jx = ix + dx
jx += ifelse(jx > Nx,-Nx,0)
jx += ifelse(jx < 1,Nx,0)
d = dx^2
if d == 1
n = 1
elseif d == 4
n = 2
else
n = -1
end
if n < 0
continue
end
if n > num
continue
end
for k=1:3
Sj[k] = s[k,jx]
end
for k=1:3
dHds[k,ix] += -J[n]*Sj[k]
dHds[k,jx] += -J[n]*Si[k]
end
end
end
dHds*ybar
end)
return sbar,ydHds,ydHdJ
end
return y,pullback
end
using Random
function main()
Random.seed!(123)
Lx = 10
s = rand(3,Lx)
J = rand(3)
Heff = calc_Heff(s,J)
println(Heff)
f(J) = calc_Heff(s,J)
grad = gradient(f,J)[1]
d = 1e-4
for i=1:length(J)
println("i = $i")
Jp = deepcopy(J)
Jp[i] += d
dHdJ = (f(Jp)-f(J))/d
println("auto ",grad[i])
println("numerical: ",dHdJ)
end
fs(s) = calc_Heff(s,J)^2
grad = gradient(fs,s)[1]
for i=1:length(s)
println("i = $i")
sp = deepcopy(s)
sp[i] += d
dHds = (fs(sp)-fs(s))/d
println("numerical: ",dHds)
println("auto ",grad[i])
end
end
main()
です。このコードは数値微分と自動微分を比較しています。なお、ここでは@thunk
というマクロを使っていますが、これは、pullbackが必要になった時だけ計算するようにするマクロです。
出力結果は、
-9.824735712364678
i = 1
auto -10.635164450712788
numerical: -10.635164450683021
i = 2
auto -10.330315751885076
numerical: -10.330315751900798
i = 3
auto 1.0
numerical: 0.9999999999976694
i = 1
numerical: 38.29405817171505
auto 38.29367837394289
i = 2
numerical: 44.17748724435455
auto 44.176981780513046
i = 3
numerical: 8.915029628866478
auto 8.915009044798328
i = 4
numerical: 15.633044094727211
auto 15.632980798038071
i = 5
numerical: 42.483074651755715
auto 42.48260721781677
i = 6
numerical: 7.969837641041977
auto 7.96982118981297
i = 7
numerical: 35.52387859855344
auto 35.523551762101604
i = 8
numerical: 31.212851514652584
auto 31.212599191047175
i = 9
numerical: 61.55327377115327
auto 61.55229250541251
i = 10
numerical: 54.17398578288157
auto 54.17322568869327
i = 11
numerical: 40.38850357147794
auto 40.38808109315797
i = 12
numerical: 21.520959319900612
auto 21.520839365288218
i = 13
numerical: 29.102309353419287
auto 29.102089999312557
i = 14
numerical: 40.72471887525353
auto 40.72428933392984
i = 15
numerical: 36.57395355759263
auto 36.57360711314064
i = 16
numerical: 18.185129621457463
auto 18.185043971736278
i = 17
numerical: 31.02623367269075
auto 31.0259843576836
i = 18
numerical: 43.89418157131786
auto 43.89368256927053
i = 19
numerical: 25.789037009502636
auto 25.78886475824141
i = 20
numerical: 26.3545590752301
auto 26.354379186735294
i = 21
numerical: 40.96929522773962
auto 40.968860511792855
i = 22
numerical: 44.313672602669385
auto 44.3131640181634
i = 23
numerical: 43.01551627790445
auto 43.0150370528497
i = 24
numerical: 10.140503042919136
auto 10.14047640982912
i = 25
numerical: 44.821403888590794
auto 44.82088358290219
i = 26
numerical: 22.80005100246285
auto 22.799916365669308
i = 27
numerical: 61.61837171049456
auto 61.61738836767953
i = 28
numerical: 35.37043100024562
auto 35.370106981444884
i = 29
numerical: 15.997447855085056
auto 15.997381573718716
i = 30
numerical: 48.27905533019816
auto 48.278451652587734
となります。ちゃんと自動微分ができていることがわかりますね。