LoginSignup
5
7

Juliaでの自動微分とrruleの書き方について

Last updated at Posted at 2024-02-14

Juliaで自動微分を扱うと、自分が作った型に対して自動微分ができるのが非常に面白いです。以前Juliaでの自動微分について調べてみる 前半や、Juliaでの自動微分について調べてみる 後半で基本的なことを調べてきました。また、自分で定義した型で微分する方法として、自分で作った行列に対しての自動微分についてはJuliaでの自動微分を使って行列で微分してみるJuliaでの自動微分を使って行列を引数にした複雑な関数を行列で微分してみるで述べました。複素関数のウィルティンガー微分に関しては、Juliaでの自動微分を使って、ウィルティンガー微分してみるJuliaでの自動微分を使って、ウィルティンガー微分してみる:完全版で述べました。
どのような型を扱うにせよ、rruleというものが重要でした。しかし、rruleは少しわかりにくいので、これについて整理したいと思います。

また、rruleというものを使うことで、自動微分そのものについて、改めてまとめ直しています。

入力と出力について

まず、やりたいことは、何らかの変数での微分を

using Zygote
dfdx = gradient(f,x)[1]

のように計算することです。ここで、重要な条件として、

  • 関数fは実スカラー関数であること

があります。例えば、

\begin{align}
f(x) = 2x+3
\end{align}

という関数を微分するには、

using Zygote
function test1()
    f(x) = 2 * x + 3
    x = 2
    dfdx = gradient(f, x)[1]
    println(dfdx)
end
test1()

とします。出力は

2.0

です。

そして、$x$と$y$の二変数だった場合、

\begin{align}
f(x,y) = 2xy+3
\end{align}

のように、$f$はスカラーである必要があります。この時、

function test1()
    f(x, y) = 2 * x * y + 3
    x = 2
    y = 2
    dfdx = gradient(f, x, y)
    println(dfdx)
end
test1()

とすれば、出力は

(4.0, 4.0)

となり、これはグラディエント${\rm grad} f \equiv (\partial f/\partial x,\partial f/\partial y)$を計算していることになります。

では、$x$がベクトルの時はどうなるでしょうか。例えば、

\begin{align}
f(\vec{x}) = 2[\vec{x}]_1 [\vec{x}]_2+3
\end{align}

の時は、

function test1()
    f(x) = 2 * x[1] * x[2] + 3
    x = [2, 2]
    dfdx = gradient(f, x)
    println(dfdx)
end

test1()

となり、出力は

([4.0, 4.0],)

となります。これは、中身が1つのタプルとなっています。つまり、常に、変数の数の要素のタプルに微分が計算されて入ります。もう一つ重要なことは、

  • 出力された微分は、入力された引数と同じ型を持つ

ということです。いま、ベクトル$\vec{x}$を入力としたので、出力はベクトルになっています。一方、2変数の場合には、それぞれの変数はスカラーなので、スカラーの出力が2つ出てきました。これは当たり前に思えるかもしれませんが、rruleを作るときに混乱しがちなので、頭の片隅にでも置いておいてください。
整理すると、

  • 関数f(x,y,z,t)のようなものがあったとき、引数(x1,y1,z1,t1)を与えた時、その場所での微分はgradient(f, x1,y1,z1,t1)で計算され、出力は(df/dx,df/dy,df/dz,df/dt)というタプルで与えられる

ということです。

rruleの書き方

通常Juliaで定義されている型を使うのであれば、自動微分で簡単に微分を計算することができます。一方で、自分独自の型を定義して、その自動微分を計算したいということもあるでしょう。これまで、複素数の微分や行列の微分などをやってきていました。そのような独自の型で自動微分する際に必要な関数は、rruleです。これはChainRulesCoreというパッケージに入っている関数ですが、自分の型で自動微分したいときには、この関数を多重ディスパッチで自分の型用のrruleを定義することになります。

rruleが何を計算しているかについて述べます。そのため、やりたいことについて整理しておきましょう。
まず、ある関数

\begin{align}
f(x) = f(g_1(x),g_2(x),g_3(x),\cdots,g_N(x))
\end{align}

が定義されているときに、

\begin{align}
L(x) = L(f(x))
\end{align}

という量の微分$\partial L/\partial x$を求めることを考えます。ここで、上で述べましたように、$L(x)$という関数で得られる出力は実スカラーでなければなりません。そして、$f(x)$は$x$に直接依存しているわけではなく、$g_i(x)$を通じて$x$に依存しているとします。例えば、

\begin{align}
f(x) = g_1(x)g_2(x) + g_3(x)
\end{align}

などです。つまり、この関数$f$は

\begin{align}
f(g_1,g_2,g_3) = g_1(x)g_2(x) + g_3(x)
\end{align}

という三つの引数をもつ関数とみなすことができます。
これらを順番に計算する場合(左が入力、右が出力です)、

\begin{align}
x \rightarrow g_1(x),g_2(x),g_3(x) \\
g_1,g_2,g_3 \rightarrow f(g_1,g_2,g_3) \\
f \rightarrow L(f)
\end{align}

という形になっています。

計算したい量は$\partial L/\partial x$ですが、これは$g_i$の中身がわからないとそのまま微分できません。唯一わかっているのは$\partial g_i/\partial x$です。ですので、微分の連鎖律を用いれば、

\begin{align}
\frac{\partial L}{\partial x} &= \sum_i \frac{\partial L}{\partial g_i} \frac{\partial g_i}{\partial x} \\
\frac{\partial L}{\partial g_i} &=  \frac{\partial L}{\partial f} \frac{\partial f}{\partial g_i} 
\end{align}

とすれば、$L(f)$の$f$微分、$f$の$g$微分、$g$の$x$微分、という、それぞれ直接使っている変数による微分の計算によって、$\partial L/\partial x$が計算できるようになります。

なお、計算に必要なものは、

  • 関数g_i(x)の$x$微分
  • 関数f(g)の$g$微分
  • 関数L(f)の$f$微分

となります。

さて、ここまでは比較的わかりやすい話でした。次は、もう少し複雑にしてみます。例えば、

\begin{align}
x \rightarrow \hat{G}(x) \\
\hat{G} \rightarrow \hat{F}(\hat{G}) \\
\hat{F} \rightarrow L(\hat{F})
\end{align}

とします。ここで、ハットのついた量$\hat{G}$は行列だとします。最終的に出力される$L$は実スカラーです。この場合、連鎖律は、

\begin{align}
\frac{\partial L}{\partial x} &= \sum_{ij} \frac{\partial L}{\partial G_{ji}} \frac{\partial G_{ji}}{\partial x} \\
\frac{\partial L}{\partial G_{ji}} &= \sum_{\alpha \beta} \frac{\partial L}{\partial F_{\beta \alpha}} \frac{\partial F_{\beta \alpha}}{\partial G_{ji}} \\
\end{align}

となります。今回は、$\frac{\partial F_{\beta \alpha}}{\partial G_{ji}}$のようなものが登場しています。これは添え字が四つあるテンソルになっています。自動微分をするためには、このテンソルを計算する必要があるのでしょうか? 実は、テンソル自体は計算する必要はありません。なぜなら、テンソルはいつも連鎖律の中でのみ登場し、四つのうち二つの添え字は常に和を取られているからです。そこで、

\begin{align}
[{\cal \hat{B}}(\hat{L},\hat{F}(\hat{G}),\hat{G})]_{ij} \equiv \sum_{\alpha \beta} \hat{L}_{\alpha \beta} \frac{\partial F_{\beta \alpha}}{\partial G_{ji}} 
\end{align}

という関数さえ定義されていれば、連鎖律が計算できることになります。この${\cal B}$をpullbackと呼びます。そして、この${\cal B}$を計算する関数をrruleと呼んでいるのです。

このpullbackに関して気をつけなければならないことが2点あります。

  • ${\cal B}$の型と大きさは$\hat{G}$の型と大きさと同じである。つまり、$\hat{G}$が行列であれば${\cal B}$は同じサイズの行列。
  • $\hat{L}$の型と大きさは$\hat{F}$の型と大きさと同じである。つまり、$\hat{F}$が行列であれば、$\hat{L}$は同じ大きさの行列。

この2点は、一つの変数しかない関数の時はそんなに意識する必要はありませんが、変数が2つを超えると気をつけないとミスをします。

一般化しますと、ある関数$F_k$($k = 1,\cdots,M$)が$G_i$にあらわに依存している場合($F_k(G_1,G_2,\cdots,G_N)$)、そのpullbackは

\begin{align}
[{\cal B}(L,F(G),G)]_{i} \equiv \sum_{k} L_k \frac{\partial F_{k}}{\partial G_{i}} 
\end{align}

です。

2変数関数の場合、つまり、ある関数$F_k$($k = 1,\cdots,M$)が$G_i$と$w_l$にあらわに依存している場合($F_k(G_1,G_2,\cdots,G_N,w_1,\cdots,w_L)$)、pullbackは

\begin{align}
[{\cal B}(L,F(G,w),G)]_{i} \equiv \sum_{k} L_k \frac{\partial F_{k}}{\partial G_{i}} \\
[{\cal B}(L,F(G,w),w)]_{l} \equiv \sum_{k} L_k \frac{\partial F_{k}}{\partial w_{l}} 
\end{align}

の二つになります。ここで、前者は$N$個の要素を持ちますが、後者は$L$個の要素を持ちます。一般的には、変数$G$と変数$w$の型が違えば、pullbackの型も同様に違います。

実際の例

行列の型

実際の例で見てみます。使うのは、Juliaでの自動微分を使って行列で微分してみるで定義した型Fieldです。これは、1次元の格子点のそれぞれに行列が載っているような変数です。
ここに再掲します。

using ChainRulesCore
using LinearAlgebra
using Zygote

struct Field{Nc,L}
    A::Array{Float64,3}
end

function Field(Nc, L)
    A = zeros(Float64, Nc, Nc, L)
    return Field{Nc,L}(A)
end

function identity(Nc, L)
    A = zeros(Float64, Nc, Nc, L)
    for i = 1:L
        for ic = 1:Nc
            A[ic, ic, i] = 1
        end
    end
    return Field{Nc,L}(A)
end

function random_field(Nc, L)
    A = rand(Float64, Nc, Nc, L)
    return Field{Nc,L}(A)
end


function Base.copy(a::Field{Nc,L}) where {Nc,L}
    A = zeros(Float64, Nc, Nc, L)
    for i = 1:L
        for ic = 1:Nc
            for jc = 1:Nc
                A[jc, ic, i] = a.A[jc, ic, i]
            end
        end
    end
    return Field{Nc,L}(A)
end

function Base.zero(a::Field{Nc,L}) where {Nc,L}
    return Field(Nc, L)
end


function Base.:+(a::Field{Nc,L}, b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i = 1:L
        c.A[:, :, i] = view(a.A, :, :, i) .+ view(b.A, :, :, i)
    end
    return c
end

function Base.:-(a::Field{Nc,L}, b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i = 1:L
        c.A[:, :, i] = view(a.A, :, :, i) .- view(b.A, :, :, i)
    end
    return c
end


function Base.:*(a::Field{Nc,L}, b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i = 1:L
        mul!(view(c.A, :, :, i), view(a.A, :, :, i), view(b.A, :, :, i))
    end
    return c
end

function Base.:*(a::T, b::Field{Nc,L}) where {Nc,L,T<:Number}
    c = zero(b)
    for i = 1:L
        c.A[:, :, i] = a * view(b.A, :, :, i)
    end
    return c
end

function Base.:*(b::Field{Nc,L}, a::T) where {Nc,L,T<:Number}
    c = zero(b)
    for i = 1:L
        c.A[:, :, i] = a * view(b.A, :, :, i)
    end
    return c
end

function Base.display(a::Field{Nc,L}) where {Nc,L}
    for i = 1:L
        println("i = $i")
        display(a.A[:, :, i])
        println("\t")
    end
end

function Base.adjoint(a::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i = 1:L
        for ic = 1:Nc
            for jc = 1:Nc
                c.A[ic, jc, i] = a.A[jc, ic, i]
            end
        end
    end
    return c
end


function LinearAlgebra.tr(a::Field{Nc,L}) where {Nc,L}
    c = 0.0
    for i = 1:L
        for ic = 1:Nc
            c += a.A[ic, ic, i]
        end
    end
    return c
end



function ChainRulesCore.rrule(::typeof(+), a::Field{Nc,L}, b::Field{Nc,L}) where {Nc,L}
    y = a + b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = ybar
        fbbar = ybar
        return sbar, fabar, fbbar
    end
    return y, pullback
end

function ChainRulesCore.rrule(::typeof(*), a::Field{Nc,L}, b::Field{Nc,L}) where {Nc,L}
    y = a * b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = ybar * b'
        fbbar = a' * ybar
        return sbar, fabar, fbbar
    end
    return y, pullback
end

function ChainRulesCore.rrule(::typeof(tr), a::Field{Nc,L}) where {Nc,L}
    y = tr(a)
    function pullback(ybar)
        sbar = NoTangent()
        fbar = identity(Nc, L) * ybar
        return sbar, fbar
    end
    return y, pullback
end

function calc_f(x, d)
    c = tr(x * x + d)
    return c
end


これを使えば、行列同士の積の微分などが可能になります。

rruleの定義

さて、新しい関数として、

\begin{align}
[S(w_1,w_2,F)]_{ijn} = w_1 F_{ijn} + w_2 (F^2)_{ijn}
\end{align}

というものを考えます。ここで、$i$と$j$は行列の添え字、$n$は格子点を表す添え字です。そして、この関数を引数とした、

\begin{align}
f(S) = \sum_{n} \sum_{i} S_{iin}
\end{align}

という量を考えます。この関数$f$を$F$で微分したり$w_1$で微分したりしたいわけです。

まず、確かめ用の数値微分を行う関数を

function numerical_derivative(f, a::Field{Nc,L}, ic, jc, i) where {Nc,L}
    delta = 1e-9
    fa = f(a)
    a_p = copy(a)
    a_p.A[ic, jc, i] += delta
    fad = f(a_p)
    df = (fad - fa) / delta
    return df
end

とします。これはFieldで数値微分した値を出します。また、パラメータで数値微分した場合の計算をする関数を

function numerical_derivative_param(f, a::Field{Nc,L}, w) where {Nc,L}
    delta = 1e-9
    fa = f(a, w)
    fad = f(a, w + delta)
    df = (fad - fa) / delta
    return df
end

としておきます。これで、関数$S$を

function smearing(w1, w2, F::Field{Nc,L}) where {Nc,L}
    return w1 * F + w2 * F * F
end

と定義しておくことで、

function test()
    Nc = 2
    L = 4
    U0 = identity(Nc, L)
    a = random_field(Nc, L)
    b = random_field(Nc, L)

    w1 = 0.2
    w2 = 1.3
    f(x) = tr(smearing(w1, w2, x))
    c = 2 * a * b
    println(f(c))

    df = numerical_derivative(f, a, 1, 1, 1)
    println("numerical: ")
    println(df)

    f2(x, y) = tr(smearing(y, w2, x))
    dfdw_n = numerical_derivative_param(f2, a, w1)
    println("numerical dfdw: ")
    println(dfdw_n)

    return
end
test()

のようにすることで、数値微分を実行できます。

次に、自動微分です。Field型に対する自動微分を行うには、rruleを定義します。これは

function ChainRulesCore.rrule(::typeof(smearing), w1, w2, F::Field{Nc,L}) where {Nc,L}
    y = smearing(w1, w2, F)
    function pullback(ybar)
        sbar = NoTangent()
        w1bar = @thunk(sum(ybar.A .* F.A))
        w2bar = @thunk(sum(ybar.A .* (F * F).A))
        fbar = @thunk(w1 * ybar + w2 * (F * ybar + ybar * F))
        return sbar, w1bar, w2bar, fbar
    end
    return y, pullback
end

となります。前述しましたように、pullbackは変数の型と同じになりますから、$w_i$が実数の場合、w1barおよびw2barは実数の型、fbarFieldの型になります。また、w1barの計算では、Fieldの添字三種類全ての和をとっていることに注意してください。これは、連鎖律でそのように和を取るためにそうなっています。そして、ybarは$L$のことですから、$S$の型と等しくなっています。この場合はField型です。

このように定義することで、

function test()
    Nc = 2
    L = 4
    U0 = identity(Nc, L)
    a = random_field(Nc, L)
    b = random_field(Nc, L)

    w1 = 0.2
    w2 = 1.3
    f(x) = tr(smearing(w1, w2, x))
    c = 2 * a * b
    println(f(c))

    df = numerical_derivative(f, a, 1, 1, 1)
    println("numerical: ")
    println(df)
    ff = gradient(f, a)
    println("autodiff: ")
    println(ff[1].A[1, 1, 1])

    f2(x, y) = tr(smearing(y, w2, x))
    dfdw_n = numerical_derivative_param(f2, a, w1)
    println("numerical dfdw: ")
    println(dfdw_n)
    dfdw_a = gradient(w1 -> tr(smearing(w1, w2, a)), w1)[1]
    println("autodiff dfdw: ")
    println(dfdw_a)

    println(tr(a))

    return
end
test()

27.674958611716672
numerical: 
1.6709718053675715
autodiff: 
1.67097203939495
numerical dfdw: 
2.5759172572747957
autodiff dfdw: 
2.5759171386745447
2.5759171386745447

という出力になり、ちゃんと数値微分と自動微分が合っています。

5
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
7