Juliaでの自動微分について調べてみる 前半の続きです。ここでは、自力で自動微分を実装してみます。その後、ChainRulesCoreを使って、既存のパッケージに自分の独自の型に関する自動微分を追加してみようと思います。
前半のおさらい
ある関数$f(x)$の微分を計算したい、とします。例えば、
a(x) = \sin(x) \\
b(a) = 0.2 + a \\
c(b) = \sqrt{1+b} \\
f(x) = c(b(a(x)))
function sample_func(x)
a = sin(x)
b = 0.2 + a
c = sqrt(1+b)
return c
end
のように定義された関数を微分することを考えます。この微分はReverse modeでは
\frac{\partial c}{\partial c} =1 \\
\frac{\partial c}{\partial b} = \frac{\partial c}{\partial b} \frac{\partial c}{\partial c} \\
\frac{\partial c}{\partial a} = \frac{\partial b}{\partial a} \frac{\partial c}{\partial b} \\
\frac{\partial c}{\partial x} = \frac{\partial a}{\partial x} \frac{\partial c}{\partial a} \\
\frac{\partial f}{\partial x} = \frac{\partial c}{\partial x}
によって計算可能です。前回定義したpullback関数:
\bar{x} = \frac{\partial c}{\partial x} = \frac{\partial y}{\partial x} \frac{\partial c}{\partial y} = \frac{\partial y}{\partial x} \bar{y} \equiv B_{y}(\bar{y})
を用いると、
\bar{c} =1 \\
\bar{b}= \frac{\partial c}{\partial b} \bar{c} = B_{c}(\bar{c})\\
\bar{a} = \frac{\partial b}{\partial a} \bar{b} = B_{b}(\bar{b}) \\
\bar{x} = \frac{\partial a}{\partial x} \bar{a} = B_{a}(\bar{a}) \\
\frac{\partial f}{\partial x} = \bar{x}
となります。ここで、
B_{a}(\bar{a}) = \cos(x) \bar{a} \\
B_{b}(\bar{b}) = \bar{b} \\
B_{c}(\bar{c}) = \frac{\bar{c}}{2 \sqrt{1+b}}
ですので、$\bar{c}=1$から順番に計算していけば微分が計算できます。
実装
それでは、pullback関数を実装してみます。Juliaでは、引数にある種類の関数が入ったときの挙動を定義することができまして、
function const_pullback(::typeof(sin),x)
pullback(ybar) = cos(x)*ybar
y = sin(x)
return y,pullback
end
とすると、第一引数にsin
が入った時にこの関数が呼ばれます。なお、返り値の2番目は関数です。他のpullback関数も同様に実装すると、
function const_pullback(::typeof(+),a,b)
pullback(ybar) = (ybar,ybar)
y = a + b
return y,pullback
end
function const_pullback(::typeof(sqrt),a)
pullback(ybar) = 0.5*ybar/sqrt(a)
y = sqrt(a)
return y,pullback
end
となります。これらを使って、特定の値$x$が入ってきた時のpullback関数を
x = 3
a,a_pullback = const_pullback(sin,x)
b,b_pullback = const_pullback(+,0.2,a)
c,c_pullback = const_pullback(sqrt,1+b)
println(c)
println(sample_func(x))
と定義します。最後に得られたc
とsample_func(x)
は等しくなります。
これらの関数を用いて、
cbar = 1 #∂c/∂c
bbar = c_pullback(cbar) #∂c/∂b = ∂c/∂b ∂c/∂c
_,abar = b_pullback(bbar) #∂c/∂a = ∂b/∂a ∂c/∂b
xbar = a_pullback(abar) #∂c/∂x = ∂a/∂x ∂c/∂a
println(xbar)
println(sample_func'(x))
とすると、xbar
は微分の値sample_func'(x)
が計算されています。これで複雑な関数の微分が計算できました。
仕組み
実際の自動微分パッケージの正確な仕組みについては理解しているとは言えませんが(詳しくは前半で紹介した論文を見てください)、Juliaは書かれたコードを解析してLLVMに直すのと同じように、書かれたコードからsin関数やその他を見つけたらpullback関数を考える、みたいな感じで自動微分を実行していると思われます。
Zygoteへの実装
問題設定
それでは、自動微分パッケージのZygoteに自分のオリジナルの関数の微分を追加してみましょう。
関数としては、
S(U_1(t),U_2(t)) = {\rm tr} \: (f(U_1(t),U_2(t))) \\
f(U_1(t),U_2(t)) = g(U_1(t),U_2(t)) + 10 \\
g(U_1(t),U_2(t)) = U_1(t)*U_2(t) \\
U_1(t) = {\cal F}(\cos (t)) \\
U_2(t) = {\cal F}(\exp (t))
とします。ここで、
{\rm tr} \: {\cal F}(x) = x \\
{\cal F}(x)*{\cal F}(y) = {\cal F}(x*y) \\
{\cal F}(x) + a = {\cal F}(x+a)
と定義します。関数$S$は全てを代入すると、
S(U_1(t),U_2(t)) = {\rm tr} \: ({\cal F}(\cos (t))*{\cal F}(\exp (t))+10) = \cos (t) \exp (t)
となりますから、その$t$微分は
\frac{\partial S}{\partial t} = - \sin (t) \exp (t) + \cos(t) \exp (t)
となります。この$t$微分を自動微分で計算することとします。
関数の評価
まず、関数$S$を計算するコードを書いておきます。途中で${\cal F}$というものがありますから、これを構造体として、
using LinearAlgebra
struct Field
A::Float64
end
function Base.:*(a::Field,b::Field)
Field(a.A*b.A)
end
function Base.:+(a::Field,b::T) where T <: Real
Field(a.A+b)
end
function LinearAlgebra.tr(a::Field)
return a.A
end
を定義しておきます。これで和と積とtrが定義できていますので、
calc_S(f) = tr(f)
calc_f(g) = g + 10
calc_g(U1,U2) = U1*U2
calc_U(x) = Field(x)
calc_x1(t) = exp(t)
calc_x2(t) = cos(t)
と関数を定義しておけば、
function calc_St(t)
U1 = calc_U(calc_x1(t))
U2 = calc_U(calc_x2(t))
g = calc_g(U1,U2)
f = calc_f(g)
S = calc_S(f)
return S
end
で関数$S$を計算できます。
pullback関数の実装
ChainRulesCoreを使うと、rruleというものでpullback関数を実装しておけば、Zygoteを使って微分することができるようになります。詳しくはこちらを参考にしてください。
まず、calc_S
はfが引数なので、
function ChainRulesCore.rrule(::typeof(calc_S),f::Field)
y = calc_S(f)
function pullback(ybar)
sbar = NoTangent()
fbar = ybar
return sbar,fbar
end
return y, pullback
end
となります。calc_f
はgが引数なので
function ChainRulesCore.rrule(::typeof(calc_f),g::Field)
y = calc_f(g)
function pullback(ybar)
fbar = NoTangent()
gbar =Tangent{Field}(; A=ybar)
return fbar,gbar
end
return y, pullback
end
となります。calc_g
はU1
とU2
が引数なので、
function ChainRulesCore.rrule(::typeof(calc_g),U1::Field,U2::Field)
y = calc_g(U1,U2)
function pullback(ybar)
fbar = NoTangent()
dU1 =Tangent{Field}(; A=ybar*U2.A)
dU2 =Tangent{Field}(; A=ybar*U1.A)
return fbar,dU1,dU2
end
return y, pullback
end
となります。
次に、構造体Field
を生成する関数calc_U
の引数はx
として、
function ChainRulesCore.rrule(::typeof(calc_U),x::T) where T<: Real
y = calc_U(x)
function pullback(ybar)
fbar = NoTangent()
dx = ybar
return fbar,dx
end
return y, pullback
end
とします。最後に、calc_x1
およびcalc_x2
に関してpullbackを定義して、
function ChainRulesCore.rrule(::typeof(calc_x1),x::T) where T<: Real
y = calc_x1(x)
function pullback(ybar)
fbar = NoTangent()
dx = ybar*y
return fbar,dx
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(calc_x2),x::T) where T<: Real
y = calc_x2(x)
function pullback(ybar)
fbar = NoTangent()
dx = ybar*(-sin(x))
return fbar,dx
end
return y, pullback
end
を定義することで、t
を引数とする関数calc_St(t)
で使われている全ての関数に関してpullback関数を定義することができました。あとは、
function test()
t = 0.4
S = calc_St(t)
println(S)
#return
dt = 1e-5
dS = (calc_St(t+dt)-calc_St(t))/dt
println("numerical result: ",dS)
println("analytical result: ",(-sin(t)+cos(t))*exp(t))
dS_AD = calc_St'(t)
println("AD result: ",calc_St'(t))
end
test()
のようにして、calc_St'(t)
とすることで、Zygoteで関数$S$の$t$微分が計算できるようになります。
まとめ
このように、途中に独自型を定義してあっても、ちゃんと微分を計算しきることができました。