LoginSignup
15
12

More than 1 year has passed since last update.

Juliaでの自動微分について調べてみる 後半

Posted at

Juliaでの自動微分について調べてみる 前半の続きです。ここでは、自力で自動微分を実装してみます。その後、ChainRulesCoreを使って、既存のパッケージに自分の独自の型に関する自動微分を追加してみようと思います。

前半のおさらい

ある関数$f(x)$の微分を計算したい、とします。例えば、

    a(x) = \sin(x) \\
    b(a) = 0.2 + a \\
    c(b) = \sqrt{1+b} \\
    f(x) = c(b(a(x)))
function sample_func(x)
    a = sin(x)
    b = 0.2 + a
    c = sqrt(1+b)
    return c
end

のように定義された関数を微分することを考えます。この微分はReverse modeでは

\frac{\partial c}{\partial c} =1 \\
\frac{\partial c}{\partial b} = \frac{\partial c}{\partial b} \frac{\partial c}{\partial c} \\
\frac{\partial c}{\partial a} = \frac{\partial b}{\partial a} \frac{\partial c}{\partial b} \\
\frac{\partial c}{\partial x} = \frac{\partial a}{\partial x} \frac{\partial c}{\partial a} \\
\frac{\partial f}{\partial x} = \frac{\partial c}{\partial x} 

によって計算可能です。前回定義したpullback関数:

 \bar{x} = \frac{\partial c}{\partial x} = \frac{\partial y}{\partial x}  \frac{\partial c}{\partial y} = \frac{\partial y}{\partial x}  \bar{y} \equiv B_{y}(\bar{y})

を用いると、

\bar{c} =1 \\
\bar{b}= \frac{\partial c}{\partial b} \bar{c} = B_{c}(\bar{c})\\
\bar{a} = \frac{\partial b}{\partial a} \bar{b} = B_{b}(\bar{b}) \\
\bar{x} = \frac{\partial a}{\partial x} \bar{a} = B_{a}(\bar{a}) \\
\frac{\partial f}{\partial x} = \bar{x}

となります。ここで、

B_{a}(\bar{a}) = \cos(x) \bar{a} \\
B_{b}(\bar{b}) = \bar{b} \\ 
B_{c}(\bar{c}) = \frac{\bar{c}}{2 \sqrt{1+b}}

ですので、$\bar{c}=1$から順番に計算していけば微分が計算できます。

実装

それでは、pullback関数を実装してみます。Juliaでは、引数にある種類の関数が入ったときの挙動を定義することができまして、

function const_pullback(::typeof(sin),x)
    pullback(ybar) = cos(x)*ybar
    y = sin(x)
    return y,pullback
end

とすると、第一引数にsinが入った時にこの関数が呼ばれます。なお、返り値の2番目は関数です。他のpullback関数も同様に実装すると、

function const_pullback(::typeof(+),a,b)
    pullback(ybar) = (ybar,ybar)
    y = a + b
    return y,pullback
end

function const_pullback(::typeof(sqrt),a)
    pullback(ybar) = 0.5*ybar/sqrt(a)
    y = sqrt(a)
    return y,pullback
end

となります。これらを使って、特定の値$x$が入ってきた時のpullback関数を

x = 3
a,a_pullback = const_pullback(sin,x)
b,b_pullback = const_pullback(+,0.2,a)
c,c_pullback = const_pullback(sqrt,1+b)
println(c)
println(sample_func(x))

と定義します。最後に得られたcsample_func(x)は等しくなります。
これらの関数を用いて、

cbar = 1 #∂c/∂c
bbar = c_pullback(cbar) #∂c/∂b = ∂c/∂b ∂c/∂c
_,abar = b_pullback(bbar) #∂c/∂a = ∂b/∂a ∂c/∂b
xbar = a_pullback(abar) #∂c/∂x = ∂a/∂x ∂c/∂a
println(xbar)
println(sample_func'(x))

とすると、xbarは微分の値sample_func'(x)が計算されています。これで複雑な関数の微分が計算できました。

仕組み

実際の自動微分パッケージの正確な仕組みについては理解しているとは言えませんが(詳しくは前半で紹介した論文を見てください)、Juliaは書かれたコードを解析してLLVMに直すのと同じように、書かれたコードからsin関数やその他を見つけたらpullback関数を考える、みたいな感じで自動微分を実行していると思われます。

Zygoteへの実装

問題設定

それでは、自動微分パッケージのZygoteに自分のオリジナルの関数の微分を追加してみましょう。
関数としては、

S(U_1(t),U_2(t)) = {\rm tr} \: (f(U_1(t),U_2(t))) \\
f(U_1(t),U_2(t)) = g(U_1(t),U_2(t)) + 10 \\
g(U_1(t),U_2(t)) = U_1(t)*U_2(t) \\
U_1(t) = {\cal F}(\cos (t)) \\
U_2(t)  = {\cal F}(\exp (t))

とします。ここで、

{\rm tr} \: {\cal F}(x) = x \\
{\cal F}(x)*{\cal F}(y) = {\cal F}(x*y) \\
 {\cal F}(x) + a =   {\cal F}(x+a) 

と定義します。関数$S$は全てを代入すると、

S(U_1(t),U_2(t)) = {\rm tr} \: ({\cal F}(\cos (t))*{\cal F}(\exp (t))+10) = \cos (t) \exp (t)

となりますから、その$t$微分は

\frac{\partial S}{\partial t} = - \sin (t) \exp (t) + \cos(t) \exp (t)

となります。この$t$微分を自動微分で計算することとします。

関数の評価

まず、関数$S$を計算するコードを書いておきます。途中で${\cal F}$というものがありますから、これを構造体として、

using LinearAlgebra
struct Field
    A::Float64
end

function Base.:*(a::Field,b::Field)
    Field(a.A*b.A)
end

function Base.:+(a::Field,b::T) where T <: Real
    Field(a.A+b)
end

function LinearAlgebra.tr(a::Field) 
    return a.A
end

を定義しておきます。これで和と積とtrが定義できていますので、

calc_S(f) = tr(f)
calc_f(g) = g + 10
calc_g(U1,U2) = U1*U2
calc_U(x) = Field(x)
calc_x1(t) = exp(t)
calc_x2(t) = cos(t)

と関数を定義しておけば、

function calc_St(t)
    U1 = calc_U(calc_x1(t))
    U2 = calc_U(calc_x2(t))
    g = calc_g(U1,U2)
    f = calc_f(g) 
    S = calc_S(f)
    return S
end

で関数$S$を計算できます。

pullback関数の実装

ChainRulesCoreを使うと、rruleというものでpullback関数を実装しておけば、Zygoteを使って微分することができるようになります。詳しくはこちらを参考にしてください。

まず、calc_Sはfが引数なので、

function ChainRulesCore.rrule(::typeof(calc_S),f::Field) 
    y = calc_S(f)

    function pullback(ybar)
        sbar = NoTangent()
        fbar = ybar
        return sbar,fbar
    end
    return y, pullback
end

となります。calc_fはgが引数なので

function ChainRulesCore.rrule(::typeof(calc_f),g::Field) 
    y = calc_f(g)
    function pullback(ybar)
        fbar = NoTangent()
        gbar =Tangent{Field}(; A=ybar)
        return fbar,gbar
    end
    return y, pullback
end

となります。calc_gU1U2が引数なので、

function ChainRulesCore.rrule(::typeof(calc_g),U1::Field,U2::Field) 
    y = calc_g(U1,U2)
    function pullback(ybar)
        fbar = NoTangent()
        dU1 =Tangent{Field}(; A=ybar*U2.A)
        dU2 =Tangent{Field}(; A=ybar*U1.A)
        return fbar,dU1,dU2
    end
    return y, pullback
end

となります。

次に、構造体Fieldを生成する関数calc_Uの引数はxとして、

function ChainRulesCore.rrule(::typeof(calc_U),x::T) where T<: Real 
    y = calc_U(x)
    function pullback(ybar)
        fbar = NoTangent()
        dx = ybar
        return fbar,dx
    end
    return y, pullback
end

とします。最後に、calc_x1およびcalc_x2に関してpullbackを定義して、

function ChainRulesCore.rrule(::typeof(calc_x1),x::T) where T<: Real 
    y = calc_x1(x)
    function pullback(ybar)
        fbar = NoTangent()
        dx = ybar*y
        return fbar,dx
    end
    return y, pullback
end

function ChainRulesCore.rrule(::typeof(calc_x2),x::T) where T<: Real 
    y = calc_x2(x)
    function pullback(ybar)
        fbar = NoTangent()
        dx = ybar*(-sin(x))
        return fbar,dx
    end
    return y, pullback
end

を定義することで、tを引数とする関数calc_St(t)で使われている全ての関数に関してpullback関数を定義することができました。あとは、

function test()
    t = 0.4
    S = calc_St(t)
    println(S)
    #return

    dt = 1e-5
    dS = (calc_St(t+dt)-calc_St(t))/dt
    println("numerical result: ",dS)
    println("analytical result: ",(-sin(t)+cos(t))*exp(t))
    dS_AD = calc_St'(t)
    println("AD result: ",calc_St'(t))
end
test()

のようにして、calc_St'(t)とすることで、Zygoteで関数$S$の$t$微分が計算できるようになります。

まとめ

このように、途中に独自型を定義してあっても、ちゃんと微分を計算しきることができました。

15
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
12