7
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Juliaでの自動微分を使って行列で微分してみる

Last updated at Posted at 2022-09-06

Juliaでの自動微分について調べてみる 前半Juliaでの自動微分について調べてみる 後半で、Juliaでの自動微分について調べました。この記事では、スカラー関数の行列微分をやってみることにします。

特に、自分で作った型に対して微分を定義して、それをZygoteの自動微分を使ってやってもらうことにします。

行列での微分の定義

ある関数$f$が、$n \times n$行列$A$に依存しているとします。例えば、

f(A) = {\rm tr}(A B)

という関数です。この関数は$n^2$個の$A_{ij}$を変数に持つ関数とも言うことができます。ですので、それぞれの変数で微分したものを

\frac{\partial f(A)}{\partial A_{ij}}

を考えることができます。この微分は$n^2$個あります。ここで、この微分をある行列の要素とみなせば、

\left[ \frac{\partial f(A)}{\partial A} \right]_{ij} \equiv \frac{\partial f(A)}{\partial A_{ij}}

という行列$\partial f/\partial A$を定義できます。ここで、行列の微分の定義にはもう一つありまして、

\left[ \frac{\partial f(A)}{\partial A} \right]_{ij} \equiv \frac{\partial f(A)}{\partial A_{ji}}

と$i,j$が逆になっているものもありますので、注意が必要です。

定義に従って、微分してみましょう。定義通りにやると

\frac{\partial f(A)}{\partial A_{ji}} = \frac{\partial {\rm tr}(A B)}{\partial A_{ji}} = \frac{\partial \sum_{l,m} A_{lm} B_{ml}}{\partial A_{ji}}  = B_{ji}

なりますから、行列での微分は

\frac{\partial f(A)}{\partial A} = B^T

となります。

連鎖律とpullback

次に、連鎖律について考えます。ある関数$f(A)$の行列$A$が別の行列の関数となっている$A = g(B)$とします。この時、関数$f$の行列$B$での微分を考えます。行列の微分といっても要素ごとに考えれば通常の微分と変わりませんので、

\left[\frac{\partial f(A)}{\partial B} \right]_{ij} = \frac{\partial f(A)}{\partial B_{ij}} = \sum_{l,m} \frac{\partial f(A)}{\partial A_{lm}}\frac{\partial A_{lm}}{\partial B_{ij}}

という形になります。

さて、Zygoteでの微分を行いたいため、pullbackを定義しようと思います。pullbackについて思い返してみると、ある関数$l(y)$があって、$y$は$x$の関数$y(x)$の時、$\partial l/\partial y \equiv \bar{y}$が与えられた時に、

{\cal B}_y(\bar{y}) \equiv  \frac{\partial l}{\partial x}  = \bar{x} = \bar{y}  \frac{\partial y(x)}{\partial x}

と定義されているものでした。
いま、$x$や$y$が行列になっているため、$\partial l/\partial y$や$\partial l/\partial x$も行列です。一方、$\partial x/\partial y$は、行列を行列で微分することになりますから、$x$の行列の足と$y$の行列の足を指定して初めて一つの値が定まります形になっており、これはテンソルです。テンソルは扱うのが面倒そうですが、幸いなことに、pullbackは$\partial l/\partial x$ですから、行列です。つまり、pullbackを用いていれば、テンソルをあらわに扱う必要はありません。
そこで、pullbackを

\left[ {\cal B}_{B}(\bar{B}) \right]_{ij} \equiv \left[  \frac{\partial f}{\partial A}  \right]_{ij} = \sum_{l,m} \frac{\partial f(A)}{\partial B_{lm}}\frac{\partial B_{lm}}{\partial A_{ij}}

と定義します。このpullbackを使って本当に微分ができるか試してみましょう。

$f(A) = {\rm tr} (A^2+ 1)$という関数の$A$微分を考えます。この関数は

B = A^2 \\
C = B + 1 \\
c = {\rm tr} C

という二つの形に分けることができます。連鎖律を使うと、

\frac{\partial f(A)}{\partial A_{ij}} = \sum_{l,m} \frac{\partial f(B(A))}{\partial B_{lm}} \frac{\partial B_{lm}}{\partial A_{ij}} \\
\frac{\partial f(B))}{\partial B_{ij}} = \sum_{l,m} \frac{\partial f(C(B))}{\partial C_{lm}} \frac{\partial C_{lm}}{\partial B_{ij}} \\
\frac{\partial f(C)}{\partial C_{ij}} = \frac{\partial c}{\partial c} \frac{\partial c}{\partial C_{ij}} 

となります。上で定義したpullbackを用いると、

\frac{\partial f(A)}{\partial A_{ij}} = \left[ {\cal B}_{B}(\bar{B}) \right]_{ij} = \sum_{l,m} \left[ {\cal B}_{C}(\bar{C}) \right]_{lm} \frac{\partial B_{lm}}{\partial A_{ij}}  \\
\frac{\partial f(B))}{\partial B_{ij}} = \left[ {\cal B}_{C}(\bar{C}) \right]_{ij} =  \sum_{l,m} \left[ {\cal B}_{c}(1) \right]_{lm} \frac{\partial C_{lm}}{\partial B_{ij}} \\
\frac{\partial f(C)}{\partial C_{ij}} =\left[ {\cal B}_{c}(1) \right]_{ij} = 1 \frac{\partial c}{\partial C_{ij}}  \\

となりますから、下から順番に計算すれば、微分が求まることになります。

具体的に計算してみますと、

\frac{\partial B_{lm}}{\partial A_{ij}} = \frac{\partial \sum_{k} A_{lk}A_{km}}{\partial A_{ij}} = A_{jm} \delta_{li} +  A_{li} \delta_{im} 

なので、

\frac{\partial f(A)}{\partial A_{ij}} = \left[ {\cal B}_{B}(\bar{B}) \right]_{ij} = \sum_{l,m} \left[ {\cal B}_{C}(\bar{C}) \right]_{lm} (A_{jm} \delta_{li} +  A_{li} \delta_{im} ) = \sum_{m} \left[ {\cal B}_{C}(\bar{C}) \right]_{im} A_{jm}  +  \sum_{l} \left[ {\cal B}_{C}(\bar{C}) \right]_{li} A_{li}  \\
= \left[{\cal B}_{C}(\bar{C}) A^T \right]_{ij} + \left[A^T {\cal B}_{C}(\bar{C}) \right]_{ij}

のようになります。ここで$\delta_{li}$はクロネッカーのデルタ($l = i$の時に1、$l \ne i$の時に0)です。

行列での微分

独自型の定義

ただの行列で微分するのもよいのですが、もう少し複雑なものを微分してみることにします。あるFieldという型を定義します。

struct Field{Nc,L}
    A::Array{Float64,3}
end

このFieldという型は内部に3次元配列$A$を持っています。この三次元配列$A$は、1次元上の格子点の上に並んだ行列を表現しているとみなします。つまり、A = zeros(Float64,3,3,10)であれば、10点の格子点の上に$3 \times 3$の零行列が置かれているとみなします。この時、このFieldを引数にとってスカラーを返す関数$f(A)$を考えます。この関数のFieldでの微分は、各格子点の上で行列微分したものです。物理学では、ある場の上に何か非可換な値が定義されていることがありますので、それを念頭においています。

このFieldを定義するために、行列のサイズNcと格子点の数Lを用いてField型を定義する関数:

function Field(Nc,L)
    A = zeros(Float64,Nc,Nc,L)
    return Field{Nc,L}(A)
end

を作ります。

次に、いくつかの行列を作成する関数を定義します。

格子点に単位行列を生成する関数:

function identity(Nc,L)
    A = zeros(Float64,Nc,Nc,L)
    for i=1:L
        for ic=1:Nc
            A[ic,ic,i] = 1
        end
    end
    return Field{Nc,L}(A)
end

格子点に乱数行列を生成する関数:

function random_field(Nc,L)
    A = rand(Float64,Nc,Nc,L)
    return Field{Nc,L}(A)
end

そして、コピーの関数と同じサイズの零行列を作成する関数:

function Base.copy(a::Field{Nc,L}) where {Nc,L}
    A = zeros(Float64,Nc,Nc,L)
    for i=1:L
        for ic=1:Nc
            for jc=1:Nc
                A[jc,ic,i] = a.A[jc,ic,i]
            end
        end
    end
    return Field{Nc,L}(A)
end

function Base.zero(a::Field{Nc,L}) where {Nc,L}
    return Field(Nc,L)
end

を定義しておきます。

Field型同士の演算を定義する関数も定義します。
足し算と引き算:

function Base.:+(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        c.A[:,:,i] = view(a.A,:,:,i) .+ view(b.A,:,:,i)
    end
    return c
end

function Base.:-(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        c.A[:,:,i] = view(a.A,:,:,i) .- view(b.A,:,:,i)
    end
    return c
end

掛け算は、格子点上での行列同士の掛け算と、スカラーと行列の掛け算を定義しておきます:

function Base.:*(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        mul!(view(c.A,:,:,i),view(a.A,:,:,i),view(b.A,:,:,i))
    end
    return c
end

function Base.:*(a::T,b::Field{Nc,L}) where {Nc,L,T <: Number}
    c = zero(b)
    for i=1:L
        c.A[:,:,i] = a*view(b.A,:,:,i)
    end
    return c
end

function Base.:*(b::Field{Nc,L},a::T) where {Nc,L,T <: Number}
    c = zero(b)
    for i=1:L
        c.A[:,:,i] = a*view(b.A,:,:,i)
    end
    return c
end

表示用の関数として

function Base.display(a::Field{Nc,L}) where {Nc,L}
    for i=1:L
        println("i = $i")
        display(a.A[:,:,i])
        println("\t")
    end
end

も定義しておきます。

あとは、それぞれの格子点での行列を転置する関数:

function Base.adjoint(a::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        for ic=1:Nc
            for jc=1:Nc
                c.A[ic,jc,i] = a.A[jc,ic,i]
            end
        end
    end
    return c
end

を定義します。

最後に、各格子点での行列のトレースをとって、それの全ての和をとる関数をField型に対するトレースとして定義しておきます。

using LinearAlgebra
function LinearAlgebra.tr(a::Field{Nc,L}) where {Nc,L}
    c = 0.0
    for i=1:L
        for ic=1:Nc
            c += a.A[ic,ic,i] 
        end
    end
    return c
end

独自型のテスト

ここまで定義したField型をテストしてみます。

using LinearAlgebra
function test()
    Nc = 2
    L = 4
    U0 = identity(Nc,L)
    a = random_field(Nc,L)
    b = random_field(Nc,L)
    c = a*b
    display(c)
    d = a+b
    display(d)
    println(tr(c))
    return
end
test()

これを実行すると、

i = 1
2×2 Matrix{Float64}:
 0.291154  0.354011
 0.748668  0.664093	
i = 2
2×2 Matrix{Float64}:
 0.175562  0.875757
 0.522483  0.938511	
i = 3
2×2 Matrix{Float64}:
 0.455412  0.816125
 0.211613  0.379623	
i = 4
2×2 Matrix{Float64}:
 0.363904  0.424232
 0.617319  0.684216	
i = 1
2×2 Matrix{Float64}:
 1.02684   0.757801
 0.981117  1.41711	
i = 2
2×2 Matrix{Float64}:
 0.901588  1.75285
 0.695556  1.36848	
i = 3
2×2 Matrix{Float64}:
 1.2097   1.77622
 0.42294  0.654737	
i = 4
2×2 Matrix{Float64}:
 0.909636  1.21337
 1.00214   1.46033	
3.952475313104248

のような出力になります。ちゃんと四則演算とトレースが計算できています。

微分の定義

pullbackを定義するため、いくつかの行列の微分を考えておきます。

行列の積

$C(A,B) = A B$の時、

\frac{\partial f}{\partial A_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}} \frac{\partial C_{lm}}{\partial A_{ij}} 

なので、

\frac{\partial C_{lm}}{\partial A_{ij}} = \frac{\partial \sum_{k} A_{lk} B_{km}}{\partial A_{ij}} = \delta_{li} B_{jm}

となります。これを代入すると、

\frac{\partial f}{\partial A_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}}  \delta_{li} B_{jm} = \sum_{m} \frac{\partial f}{\partial C_{im}} B_{jm} = \left[ \frac{\partial f}{\partial C} B^T  \right]_{ij}

となります。次に、$B$で微分してみますと、

\frac{\partial f}{\partial B_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}} \frac{\partial C_{lm}}{\partial B_{ij}} 

より、

\frac{\partial C_{lm}}{\partial B_{ij}} = \frac{\partial \sum_{k} A_{lk} B_{km}}{\partial B_{ij}} = A_{li} \delta_{jm} 

となりますから、

\frac{\partial f}{\partial B_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}}  A_{li} \delta_{jm}  = \sum_{l} \frac{\partial f}{\partial C_{lj}} A_{li} = \left[ A^T \frac{\partial f}{\partial C} \right]_{ij}

となります。

行列の和

$C(A,B) = A + B$の時、

\frac{\partial f}{\partial A_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}} \frac{\partial C_{lm}}{\partial A_{ij}} 

なので、

\frac{\partial C_{lm}}{\partial A_{ij}} = \frac{\partial (A_{lm} + B_{lm})}{\partial A_{ij}} = \delta_{li} \delta_{mj}

となりますから、これを代入すると、

\frac{\partial f}{\partial A_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}} \delta_{li} \delta_{mj}=\frac{\partial f}{\partial C_{ij}} = \left[ \frac{\partial f}{\partial C}   \right]_{ij}

となります。$B$に関しても同様で、

\frac{\partial f}{\partial B_{ij}} = \left[ \frac{\partial f}{\partial C}   \right]_{ij}

です。

トレース

$C(A) = {\rm tr} A$の時、

\frac{\partial f}{\partial A_{ij}} = \frac{\partial f}{\partial C} \frac{\partial C}{\partial A_{ij}} 

となりますが、

\frac{\partial C}{\partial A_{ij}} = \frac{\partial \sum_{k} A_{kk} }{\partial A_{ij}} = \delta_{ij} 

ですので、

\frac{\partial f}{\partial A_{ij}} = \frac{\partial f}{\partial C} \delta_{ij} 

となります。

rruleの実装

あとは、前の記事と同様にrruleを実装します。行列の積と和とトレースを定義します。

using ChainRulesCore
function ChainRulesCore.rrule(::typeof(+),a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L} 
    y = a + b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = ybar
        fbbar = ybar
        return sbar,fabar,fbbar
    end
    return y, pullback
end

function ChainRulesCore.rrule(::typeof(*),a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L} 
    y = a * b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = ybar*b'
        fbbar = a'*ybar
        return sbar,fabar,fbbar
    end
    return y, pullback
end

function ChainRulesCore.rrule(::typeof(tr),a::Field{Nc,L}) where {Nc,L} 
    y = tr(a)
    function pullback(ybar)
        sbar = NoTangent()
        fbar = identity(Nc,L)*ybar
        return sbar,fbar
    end
    return y, pullback
end

これでOKです。

自動微分

それでは、自動微分してみます。関数は${\rm tr}(A^2 + I)$とします。$I$は単位行列です。この関数は

function calc_f(x,d)
    c = tr(x*x + d)
    return c
end

で定義できます。

自動微分が正しいかを比較するために、数値微分も定義しておきます。Fieldの1箇所だけ微小変化させることで、数値微分を定義しておき、それを比較することにします。自動微分にはZygoteのgradientを使うことにします。
コードは以下の通りです。

using Zygote
function test()
    Nc = 2
    L = 4
    U0 = identity(Nc,L)
    a = random_field(Nc,L)
    b = random_field(Nc,L)
    c = a*b
    display(c)
    d = a+b
    display(d)
    println(tr(c))

    fa = calc_f(a,U0)
    println(fa)

    a_p = copy(a)
    delta = 1e-9
    a_p.A[1,1,1] += delta 
    fad = calc_f(a_p,U0)
    df = (fad - fa)/delta
    println(df)

    ff = gradient(x -> calc_f(x,U0),a)
    println(ff[1].A[1,1,1])

    return
end
test()

これを実行すると、数値微分と自動微分がよく一致していることがわかると思います。

全体のコード

全体のコードを貼っておきます。

struct Field{Nc,L}
    A::Array{Float64,3}
end

function Field(Nc,L)
    A = zeros(Float64,Nc,Nc,L)
    return Field{Nc,L}(A)
end

function identity(Nc,L)
    A = zeros(Float64,Nc,Nc,L)
    for i=1:L
        for ic=1:Nc
            A[ic,ic,i] = 1
        end
    end
    return Field{Nc,L}(A)
end

function random_field(Nc,L)
    A = rand(Float64,Nc,Nc,L)
    return Field{Nc,L}(A)
end


function Base.copy(a::Field{Nc,L}) where {Nc,L}
    A = zeros(Float64,Nc,Nc,L)
    for i=1:L
        for ic=1:Nc
            for jc=1:Nc
                A[jc,ic,i] = a.A[jc,ic,i]
            end
        end
    end
    return Field{Nc,L}(A)
end

function Base.zero(a::Field{Nc,L}) where {Nc,L}
    return Field(Nc,L)
end


function Base.:+(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        c.A[:,:,i] = view(a.A,:,:,i) .+ view(b.A,:,:,i)
    end
    return c
end

function Base.:-(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        c.A[:,:,i] = view(a.A,:,:,i) .- view(b.A,:,:,i)
    end
    return c
end


function Base.:*(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        mul!(view(c.A,:,:,i),view(a.A,:,:,i),view(b.A,:,:,i))
    end
    return c
end

function Base.:*(a::T,b::Field{Nc,L}) where {Nc,L,T <: Number}
    c = zero(b)
    for i=1:L
        c.A[:,:,i] = a*view(b.A,:,:,i)
    end
    return c
end

function Base.:*(b::Field{Nc,L},a::T) where {Nc,L,T <: Number}
    c = zero(b)
    for i=1:L
        c.A[:,:,i] = a*view(b.A,:,:,i)
    end
    return c
end

function Base.display(a::Field{Nc,L}) where {Nc,L}
    for i=1:L
        println("i = $i")
        display(a.A[:,:,i])
        println("\t")
    end
end

function Base.adjoint(a::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        for ic=1:Nc
            for jc=1:Nc
                c.A[ic,jc,i] = a.A[jc,ic,i]
            end
        end
    end
    return c
end

using LinearAlgebra
function LinearAlgebra.tr(a::Field{Nc,L}) where {Nc,L}
    c = 0.0
    for i=1:L
        for ic=1:Nc
            c += a.A[ic,ic,i] 
        end
    end
    return c
end

#=

using LinearAlgebra
function test()
    Nc = 2
    L = 4
    U0 = identity(Nc,L)
    a = random_field(Nc,L)
    b = random_field(Nc,L)
    c = a*b
    display(c)
    d = a+b
    display(d)
    println(tr(c))
    return
end
test()

=#

using ChainRulesCore
function ChainRulesCore.rrule(::typeof(+),a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L} 
    y = a + b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = ybar
        fbbar = ybar
        return sbar,fabar,fbbar
    end
    return y, pullback
end

function ChainRulesCore.rrule(::typeof(*),a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L} 
    y = a * b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = ybar*b'
        fbbar = a'*ybar
        return sbar,fabar,fbbar
    end
    return y, pullback
end

function ChainRulesCore.rrule(::typeof(tr),a::Field{Nc,L}) where {Nc,L} 
    y = tr(a)
    function pullback(ybar)
        sbar = NoTangent()
        fbar = identity(Nc,L)*ybar
        return sbar,fbar
    end
    return y, pullback
end

function calc_f(x,d)
    c = tr(x*x + d)
    return c
end

using Zygote
function test()
    Nc = 2
    L = 4
    U0 = identity(Nc,L)
    a = random_field(Nc,L)
    b = random_field(Nc,L)
    c = a*b
    display(c)
    d = a+b
    display(d)
    println(tr(c))

    fa = calc_f(a,U0)
    println(fa)

    a_p = copy(a)
    delta = 1e-9
    a_p.A[1,1,1] += delta 
    fad = calc_f(a_p,U0)
    df = (fad - fa)/delta
    println(df)

    ff = gradient(x -> calc_f(x,U0),a)
    println(ff[1].A[1,1,1])

    return
end
test()
7
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?