4
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Flux.jlやZygote.jlで自動微分:rruleの書き方

Last updated at Posted at 2023-03-10

Juliaで機械学習をする際、Flux.jlが使われることが多いと思いますが、これにはZygote.jlによる自動微分が使われています。Zygote.jlは自動微分が可能ですが、時々、自分が作った関数が自動微分できない場合が出てくるかと思います。

問題設定

using Zygote

function calc_Heff(s,J)
    _,Nx = size(s)
    num = length(J)-1
    H = 0.0
    Si = zeros(3)
    Sj = zeros(3)
    for ix=1:Nx
        for k=1:3
            Si[k] = s[k,ix]
        end
        for dx = -2:2

            jx = ix + dx
            jx += ifelse(jx > Nx,-Nx,0)
            jx += ifelse(jx < 1,Nx,0)

            d = dx^2
            if d == 1
                n = 1
            elseif d == 4
                n = 2
            else 
                n = -1
            end
            if n < 0
                continue
            end
            if n > num
                continue
            end

            for k=1:3
                Sj[k] = s[k,jx]
            end

            for k=1:3
                H += -J[n]*Si[k]*Sj[k]
            end
        end
    end
    return H + J[end]
end

のような適当な関数

H(s,J) = \sum_{i=1}^{N_x} \sum_{l=1}^1 J_l \sum_k (s_{k,i} s_{k,i+l} + s_{k,i} s_{k,i-l}) + J_0

を考えます。ここで、$s$は$3 \times N_x$の配列です。そして、この関数を使った適当な関数

E(s,J) = H(s,J)^2

というものを考えます。

この関数のsに関する微分と

\frac{\partial E}{\partial s_{k,i}}

Jに関する微分

\frac{\partial E}{\partial J_l}

を自動微分したいとします。

自動微分するためにZygote.jlを使って、

function main()
    Lx = 10
    s = rand(3,Lx)
    J = rand(3)
    Heff = calc_Heff(s,J)
    println(Heff)

    f(J) = calc_Heff(s,J)

    grad = gradient(f,J)[1]
    
    d = 1e-4
    for i=1:length(J)
        println("i = $i")
        Jp = deepcopy(J)
        Jp[i] += d
        dHdJ = (f(Jp)-f(J))/d
        println("auto ",grad[i])
        println("numerical: ",dHdJ)
    end
        
        

    fs(s) = calc_Heff(s,J)^2 
    grad = gradient(fs,s)[1]

    for i=1:length(s)
        println("i = $i")
        sp = deepcopy(s)
        sp[i] += d
        dHds = (fs(sp)-fs(s))/d
        println("numerical: ",dHds)
    
        println("auto ",grad[i])
    end
    
end
main()

を実行してみると、

ERROR: LoadError: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

というエラーが出てしまいます。
これは、配列の要素を変更するようなものはZygote.jlでは微分できない、ということを言っています。複雑な関数を考えたときに、配列の要素にアクセスすることはそれなりにあり得るので、このエラーはそれなりによく遭遇するかもしれません。
今回はこの問題をrruleを使って解決する方法について述べます。

rrule

rruleというものは、Zygote.jlでの自動微分で使われている関数です。詳しくは、Juliaでの自動微分について調べてみる 前半を参照してください。

しかし、rruleは少しわかりにくいかもしれません。結局何をすればいいのだろう、と困ってしまうかもしれません。そこで、簡単な処方箋について、ここで述べます。
ある関数$H(s,J)$があったとします。この時、rruleではpullbackというものを定義します。これは、2変数の場合には、

B_{kl}(\frac{\partial L}{\partial H}) \equiv \frac{\partial L}{\partial H} \frac{\partial H}{\partial s_{ki}} \\
B_{l}(\frac{\partial L}{\partial H}) \equiv \frac{\partial L}{\partial H} \frac{\partial H}{\partial J_l}

という二つの関数を計算するような関数を作れば良い、ということになります。ここで、$B_{kl}(\frac{\partial L}{\partial H})$の$\frac{\partial L}{\partial H}$は引数としてZygoteからもらってくるものだということに注意してください。もし$H$が配列であれば、この$\frac{\partial L}{\partial H}$も同じサイズの配列です。

これでもまだわかりにくいと思いますので、上の例の場合のrrulesを示します。

function ChainRulesCore.rrule(::typeof(calc_Heff),s,J) 
    y = calc_Heff(s,J)
    _,Nx = size(s)
    function pullback(ybar)
        sbar = NoTangent()
        num = length(J)-1

        Si = zeros(3)
        Sj = zeros(3)
        dHdJ = zero(J)
        dHdJ[end] = 1
        for ix=1:Nx
            for k=1:3
                Si[k] = s[k,ix]
            end
            for dx = -2:2
                jx = ix + dx
                jx += ifelse(jx > Nx,-Nx,0)
                jx += ifelse(jx < 1,Nx,0)
    
                d = dx^2
                if d == 1
                    n = 1
                elseif d == 4
                    n = 2
                else 
                    n = -1
                end
                if n < 0
                    continue
                end
                if n > num
                    continue
                end

                for k=1:3
                    Sj[k] = s[k,jx]
                end

                for k=1:3
                    dHdJ[n] += -Si[k]*Sj[k]
                end

            end
        end

        dHds = zero(s)
        for ix=1:Nx
            for k=1:3
                Si[k] = s[k,ix]
            end
            for dx = -2:2
                jx = ix + dx
                jx += ifelse(jx > Nx,-Nx,0)
                jx += ifelse(jx < 1,Nx,0)
    
                d = dx^2
                if d == 1
                    n = 1
                elseif d == 4
                    n = 2
                else 
                    n = -1
                end
                if n < 0
                    continue
                end
                if n > num
                    continue
                end

                for k=1:3
                    Sj[k] = s[k,jx]
                end

                for k=1:3
                    dHds[k,ix] += -J[n]*Sj[k]
                    dHds[k,jx] += -J[n]*Si[k]
                end

            end
        end

        return sbar,dHds*ybar,dHdJ*ybar
                
    end
    return y,pullback
end

ここで、引数ybarは$\frac{\partial L}{\partial H}$のことです。今は、$\frac{\partial H}{\partial J_l}$は

\frac{\partial H}{\partial J_l} = \sum_{i=1}^{N_x} \sum_k (s_{k,i} s_{k,i+l} + s_{k,i} s_{k,i-l}) 

みたいな感じですね。そして、$\frac{\partial H}{\partial s_{ki}} $は

\frac{\partial H}{\partial s_{ki}} =  \sum_{l=1}^1 J_l  (s_{k,i+l} + s_{k,i-l}) + \sum_{l=1}^1 J_l(s_{k,i-l}  + s_{k,i+l} ) 

のような感じです。
どんな複雑な関数でも、微分を連鎖律で計算する際にこのpullbackが必要になりますから、この部分さえ書いておけばちゃんと自動微分ができるようになります。
このように、rruleさえ書いておけば、上のようなエラーで怒られることはなくなります。

なお、このpullbackの利点ですが、$H$がスカラーではなく配列の場合には

B_{kl}(\frac{\partial L}{\partial H}) \equiv \sum_{nm} \frac{\partial L}{\partial H_{nm}} \frac{\partial H_{nm}}{\partial s_{ki}} 

のように、配列の足の和を取る形で定義されている点です。このように定義されているおかげで、$H$が配列でもその微分がテンソルになる煩雑さから解放されています。もし$ \frac{\partial H_{nm}}{\partial s_{ki}} $をそのまま使う場合には、添字が四つの量を計算する必要があり、煩雑です。

全体のコード

全体のコードは


using Zygote
using ChainRulesCore

function calc_Heff(s,J)
    _,Nx = size(s)
    num = length(J)-1
    H = 0.0
    Si = zeros(3)
    Sj = zeros(3)
    for ix=1:Nx
        for k=1:3
            Si[k] = s[k,ix]
        end
        for dx = -2:2

            jx = ix + dx
            jx += ifelse(jx > Nx,-Nx,0)
            jx += ifelse(jx < 1,Nx,0)

            d = dx^2
            if d == 1
                n = 1
            elseif d == 4
                n = 2
            else 
                n = -1
            end
            if n < 0
                continue
            end
            if n > num
                continue
            end

            for k=1:3
                Sj[k] = s[k,jx]
            end

            for k=1:3
                H += -J[n]*Si[k]*Sj[k]
            end
        end
    end
    return H + J[end]
end



function ChainRulesCore.rrule(::typeof(calc_Heff),s,J) 
    y = calc_Heff(s,J)
    _,Nx = size(s)
    function pullback(ybar)
        sbar = NoTangent()
        num = length(J)-1

        Si = zeros(3)
        Sj = zeros(3)
        ydHdJ = @thunk (begin
            dHdJ = zero(J)
            dHdJ[end] = 1
            for ix=1:Nx
                for k=1:3
                    Si[k] = s[k,ix]
                end
                for dx = -2:2
                    jx = ix + dx
                    jx += ifelse(jx > Nx,-Nx,0)
                    jx += ifelse(jx < 1,Nx,0)
        
                    d = dx^2
                    if d == 1
                        n = 1
                    elseif d == 4
                        n = 2
                    else 
                        n = -1
                    end
                    if n < 0
                        continue
                    end
                    if n > num
                        continue
                    end

                    for k=1:3
                        Sj[k] = s[k,jx]
                    end

                    for k=1:3
                        dHdJ[n] += -Si[k]*Sj[k]
                    end

                end
            end
            dHdJ*ybar
        end)

        ydHds = @thunk (begin
            dHds = zero(s)
            for ix=1:Nx
                for k=1:3
                    Si[k] = s[k,ix]
                end
                for dx = -2:2
                    jx = ix + dx
                    jx += ifelse(jx > Nx,-Nx,0)
                    jx += ifelse(jx < 1,Nx,0)
        
                    d = dx^2
                    if d == 1
                        n = 1
                    elseif d == 4
                        n = 2
                    else 
                        n = -1
                    end
                    if n < 0
                        continue
                    end
                    if n > num
                        continue
                    end

                    for k=1:3
                        Sj[k] = s[k,jx]
                    end

                    for k=1:3
                        dHds[k,ix] += -J[n]*Sj[k]
                        dHds[k,jx] += -J[n]*Si[k]
                    end

                end
            end
            dHds*ybar
        end)

        return sbar,ydHds,ydHdJ
                
    end
    return y,pullback
end


using Random
function main()
    Random.seed!(123)
    Lx = 10
    s = rand(3,Lx)
    J = rand(3)
    Heff = calc_Heff(s,J)
    println(Heff)

    f(J) = calc_Heff(s,J)

    grad = gradient(f,J)[1]
    
    d = 1e-4
    for i=1:length(J)
        println("i = $i")
        Jp = deepcopy(J)
        Jp[i] += d
        dHdJ = (f(Jp)-f(J))/d
        println("auto ",grad[i])
        println("numerical: ",dHdJ)
    end
        
        

    fs(s) = calc_Heff(s,J)^2 
    grad = gradient(fs,s)[1]

    for i=1:length(s)
        println("i = $i")
        sp = deepcopy(s)
        sp[i] += d
        dHds = (fs(sp)-fs(s))/d
        println("numerical: ",dHds)
    
        println("auto ",grad[i])
    end
    
end
main()

です。このコードは数値微分と自動微分を比較しています。なお、ここでは@thunk というマクロを使っていますが、これは、pullbackが必要になった時だけ計算するようにするマクロです。

出力結果は、

-9.824735712364678
i = 1
auto -10.635164450712788
numerical: -10.635164450683021
i = 2
auto -10.330315751885076
numerical: -10.330315751900798
i = 3
auto 1.0
numerical: 0.9999999999976694
i = 1
numerical: 38.29405817171505
auto 38.29367837394289
i = 2
numerical: 44.17748724435455
auto 44.176981780513046
i = 3
numerical: 8.915029628866478
auto 8.915009044798328
i = 4
numerical: 15.633044094727211
auto 15.632980798038071
i = 5
numerical: 42.483074651755715
auto 42.48260721781677
i = 6
numerical: 7.969837641041977
auto 7.96982118981297
i = 7
numerical: 35.52387859855344
auto 35.523551762101604
i = 8
numerical: 31.212851514652584
auto 31.212599191047175
i = 9
numerical: 61.55327377115327
auto 61.55229250541251
i = 10
numerical: 54.17398578288157
auto 54.17322568869327
i = 11
numerical: 40.38850357147794
auto 40.38808109315797
i = 12
numerical: 21.520959319900612
auto 21.520839365288218
i = 13
numerical: 29.102309353419287
auto 29.102089999312557
i = 14
numerical: 40.72471887525353
auto 40.72428933392984
i = 15
numerical: 36.57395355759263
auto 36.57360711314064
i = 16
numerical: 18.185129621457463
auto 18.185043971736278
i = 17
numerical: 31.02623367269075
auto 31.0259843576836
i = 18
numerical: 43.89418157131786
auto 43.89368256927053
i = 19
numerical: 25.789037009502636
auto 25.78886475824141
i = 20
numerical: 26.3545590752301
auto 26.354379186735294
i = 21
numerical: 40.96929522773962
auto 40.968860511792855
i = 22
numerical: 44.313672602669385
auto 44.3131640181634
i = 23
numerical: 43.01551627790445
auto 43.0150370528497
i = 24
numerical: 10.140503042919136
auto 10.14047640982912
i = 25
numerical: 44.821403888590794
auto 44.82088358290219
i = 26
numerical: 22.80005100246285
auto 22.799916365669308
i = 27
numerical: 61.61837171049456
auto 61.61738836767953
i = 28
numerical: 35.37043100024562
auto 35.370106981444884
i = 29
numerical: 15.997447855085056
auto 15.997381573718716
i = 30
numerical: 48.27905533019816
auto 48.278451652587734

となります。ちゃんと自動微分ができていることがわかりますね。

4
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?