Juliaでの自動微分について調べてみる 前半とJuliaでの自動微分について調べてみる 後半で、Juliaでの自動微分について調べました。この記事では、スカラー関数の行列微分をやってみることにします。
特に、自分で作った型に対して微分を定義して、それをZygoteの自動微分を使ってやってもらうことにします。
行列での微分の定義
ある関数$f$が、$n \times n$行列$A$に依存しているとします。例えば、
f(A) = {\rm tr}(A B)
という関数です。この関数は$n^2$個の$A_{ij}$を変数に持つ関数とも言うことができます。ですので、それぞれの変数で微分したものを
\frac{\partial f(A)}{\partial A_{ij}}
を考えることができます。この微分は$n^2$個あります。ここで、この微分をある行列の要素とみなせば、
\left[ \frac{\partial f(A)}{\partial A} \right]_{ij} \equiv \frac{\partial f(A)}{\partial A_{ij}}
という行列$\partial f/\partial A$を定義できます。ここで、行列の微分の定義にはもう一つありまして、
\left[ \frac{\partial f(A)}{\partial A} \right]_{ij} \equiv \frac{\partial f(A)}{\partial A_{ji}}
と$i,j$が逆になっているものもありますので、注意が必要です。
例
定義に従って、微分してみましょう。定義通りにやると
\frac{\partial f(A)}{\partial A_{ji}} = \frac{\partial {\rm tr}(A B)}{\partial A_{ji}} = \frac{\partial \sum_{l,m} A_{lm} B_{ml}}{\partial A_{ji}} = B_{ji}
なりますから、行列での微分は
\frac{\partial f(A)}{\partial A} = B^T
となります。
連鎖律とpullback
次に、連鎖律について考えます。ある関数$f(A)$の行列$A$が別の行列の関数となっている$A = g(B)$とします。この時、関数$f$の行列$B$での微分を考えます。行列の微分といっても要素ごとに考えれば通常の微分と変わりませんので、
\left[\frac{\partial f(A)}{\partial B} \right]_{ij} = \frac{\partial f(A)}{\partial B_{ij}} = \sum_{l,m} \frac{\partial f(A)}{\partial A_{lm}}\frac{\partial A_{lm}}{\partial B_{ij}}
という形になります。
さて、Zygoteでの微分を行いたいため、pullbackを定義しようと思います。pullbackについて思い返してみると、ある関数$l(y)$があって、$y$は$x$の関数$y(x)$の時、$\partial l/\partial y \equiv \bar{y}$が与えられた時に、
{\cal B}_y(\bar{y}) \equiv \frac{\partial l}{\partial x} = \bar{x} = \bar{y} \frac{\partial y(x)}{\partial x}
と定義されているものでした。
いま、$x$や$y$が行列になっているため、$\partial l/\partial y$や$\partial l/\partial x$も行列です。一方、$\partial x/\partial y$は、行列を行列で微分することになりますから、$x$の行列の足と$y$の行列の足を指定して初めて一つの値が定まります形になっており、これはテンソルです。テンソルは扱うのが面倒そうですが、幸いなことに、pullbackは$\partial l/\partial x$ですから、行列です。つまり、pullbackを用いていれば、テンソルをあらわに扱う必要はありません。
そこで、pullbackを
\left[ {\cal B}_{B}(\bar{B}) \right]_{ij} \equiv \left[ \frac{\partial f}{\partial A} \right]_{ij} = \sum_{l,m} \frac{\partial f(A)}{\partial B_{lm}}\frac{\partial B_{lm}}{\partial A_{ij}}
と定義します。このpullbackを使って本当に微分ができるか試してみましょう。
$f(A) = {\rm tr} (A^2+ 1)$という関数の$A$微分を考えます。この関数は
B = A^2 \\
C = B + 1 \\
c = {\rm tr} C
という二つの形に分けることができます。連鎖律を使うと、
\frac{\partial f(A)}{\partial A_{ij}} = \sum_{l,m} \frac{\partial f(B(A))}{\partial B_{lm}} \frac{\partial B_{lm}}{\partial A_{ij}} \\
\frac{\partial f(B))}{\partial B_{ij}} = \sum_{l,m} \frac{\partial f(C(B))}{\partial C_{lm}} \frac{\partial C_{lm}}{\partial B_{ij}} \\
\frac{\partial f(C)}{\partial C_{ij}} = \frac{\partial c}{\partial c} \frac{\partial c}{\partial C_{ij}}
となります。上で定義したpullbackを用いると、
\frac{\partial f(A)}{\partial A_{ij}} = \left[ {\cal B}_{B}(\bar{B}) \right]_{ij} = \sum_{l,m} \left[ {\cal B}_{C}(\bar{C}) \right]_{lm} \frac{\partial B_{lm}}{\partial A_{ij}} \\
\frac{\partial f(B))}{\partial B_{ij}} = \left[ {\cal B}_{C}(\bar{C}) \right]_{ij} = \sum_{l,m} \left[ {\cal B}_{c}(1) \right]_{lm} \frac{\partial C_{lm}}{\partial B_{ij}} \\
\frac{\partial f(C)}{\partial C_{ij}} =\left[ {\cal B}_{c}(1) \right]_{ij} = 1 \frac{\partial c}{\partial C_{ij}} \\
となりますから、下から順番に計算すれば、微分が求まることになります。
具体的に計算してみますと、
\frac{\partial B_{lm}}{\partial A_{ij}} = \frac{\partial \sum_{k} A_{lk}A_{km}}{\partial A_{ij}} = A_{jm} \delta_{li} + A_{li} \delta_{im}
なので、
\frac{\partial f(A)}{\partial A_{ij}} = \left[ {\cal B}_{B}(\bar{B}) \right]_{ij} = \sum_{l,m} \left[ {\cal B}_{C}(\bar{C}) \right]_{lm} (A_{jm} \delta_{li} + A_{li} \delta_{im} ) = \sum_{m} \left[ {\cal B}_{C}(\bar{C}) \right]_{im} A_{jm} + \sum_{l} \left[ {\cal B}_{C}(\bar{C}) \right]_{li} A_{li} \\
= \left[{\cal B}_{C}(\bar{C}) A^T \right]_{ij} + \left[A^T {\cal B}_{C}(\bar{C}) \right]_{ij}
のようになります。ここで$\delta_{li}$はクロネッカーのデルタ($l = i$の時に1、$l \ne i$の時に0)です。
行列での微分
独自型の定義
ただの行列で微分するのもよいのですが、もう少し複雑なものを微分してみることにします。あるFieldという型を定義します。
struct Field{Nc,L}
A::Array{Float64,3}
end
このFieldという型は内部に3次元配列$A$を持っています。この三次元配列$A$は、1次元上の格子点の上に並んだ行列を表現しているとみなします。つまり、A = zeros(Float64,3,3,10)
であれば、10点の格子点の上に$3 \times 3$の零行列が置かれているとみなします。この時、このFieldを引数にとってスカラーを返す関数$f(A)$を考えます。この関数のFieldでの微分は、各格子点の上で行列微分したものです。物理学では、ある場の上に何か非可換な値が定義されていることがありますので、それを念頭においています。
このFieldを定義するために、行列のサイズNcと格子点の数Lを用いてField型を定義する関数:
function Field(Nc,L)
A = zeros(Float64,Nc,Nc,L)
return Field{Nc,L}(A)
end
を作ります。
次に、いくつかの行列を作成する関数を定義します。
格子点に単位行列を生成する関数:
function identity(Nc,L)
A = zeros(Float64,Nc,Nc,L)
for i=1:L
for ic=1:Nc
A[ic,ic,i] = 1
end
end
return Field{Nc,L}(A)
end
格子点に乱数行列を生成する関数:
function random_field(Nc,L)
A = rand(Float64,Nc,Nc,L)
return Field{Nc,L}(A)
end
そして、コピーの関数と同じサイズの零行列を作成する関数:
function Base.copy(a::Field{Nc,L}) where {Nc,L}
A = zeros(Float64,Nc,Nc,L)
for i=1:L
for ic=1:Nc
for jc=1:Nc
A[jc,ic,i] = a.A[jc,ic,i]
end
end
end
return Field{Nc,L}(A)
end
function Base.zero(a::Field{Nc,L}) where {Nc,L}
return Field(Nc,L)
end
を定義しておきます。
Field型同士の演算を定義する関数も定義します。
足し算と引き算:
function Base.:+(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i=1:L
c.A[:,:,i] = view(a.A,:,:,i) .+ view(b.A,:,:,i)
end
return c
end
function Base.:-(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i=1:L
c.A[:,:,i] = view(a.A,:,:,i) .- view(b.A,:,:,i)
end
return c
end
掛け算は、格子点上での行列同士の掛け算と、スカラーと行列の掛け算を定義しておきます:
function Base.:*(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i=1:L
mul!(view(c.A,:,:,i),view(a.A,:,:,i),view(b.A,:,:,i))
end
return c
end
function Base.:*(a::T,b::Field{Nc,L}) where {Nc,L,T <: Number}
c = zero(b)
for i=1:L
c.A[:,:,i] = a*view(b.A,:,:,i)
end
return c
end
function Base.:*(b::Field{Nc,L},a::T) where {Nc,L,T <: Number}
c = zero(b)
for i=1:L
c.A[:,:,i] = a*view(b.A,:,:,i)
end
return c
end
表示用の関数として
function Base.display(a::Field{Nc,L}) where {Nc,L}
for i=1:L
println("i = $i")
display(a.A[:,:,i])
println("\t")
end
end
も定義しておきます。
あとは、それぞれの格子点での行列を転置する関数:
function Base.adjoint(a::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i=1:L
for ic=1:Nc
for jc=1:Nc
c.A[ic,jc,i] = a.A[jc,ic,i]
end
end
end
return c
end
を定義します。
最後に、各格子点での行列のトレースをとって、それの全ての和をとる関数をField型に対するトレースとして定義しておきます。
using LinearAlgebra
function LinearAlgebra.tr(a::Field{Nc,L}) where {Nc,L}
c = 0.0
for i=1:L
for ic=1:Nc
c += a.A[ic,ic,i]
end
end
return c
end
独自型のテスト
ここまで定義したField型をテストしてみます。
using LinearAlgebra
function test()
Nc = 2
L = 4
U0 = identity(Nc,L)
a = random_field(Nc,L)
b = random_field(Nc,L)
c = a*b
display(c)
d = a+b
display(d)
println(tr(c))
return
end
test()
これを実行すると、
i = 1
2×2 Matrix{Float64}:
0.291154 0.354011
0.748668 0.664093
i = 2
2×2 Matrix{Float64}:
0.175562 0.875757
0.522483 0.938511
i = 3
2×2 Matrix{Float64}:
0.455412 0.816125
0.211613 0.379623
i = 4
2×2 Matrix{Float64}:
0.363904 0.424232
0.617319 0.684216
i = 1
2×2 Matrix{Float64}:
1.02684 0.757801
0.981117 1.41711
i = 2
2×2 Matrix{Float64}:
0.901588 1.75285
0.695556 1.36848
i = 3
2×2 Matrix{Float64}:
1.2097 1.77622
0.42294 0.654737
i = 4
2×2 Matrix{Float64}:
0.909636 1.21337
1.00214 1.46033
3.952475313104248
のような出力になります。ちゃんと四則演算とトレースが計算できています。
微分の定義
pullbackを定義するため、いくつかの行列の微分を考えておきます。
行列の積
$C(A,B) = A B$の時、
\frac{\partial f}{\partial A_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}} \frac{\partial C_{lm}}{\partial A_{ij}}
なので、
\frac{\partial C_{lm}}{\partial A_{ij}} = \frac{\partial \sum_{k} A_{lk} B_{km}}{\partial A_{ij}} = \delta_{li} B_{jm}
となります。これを代入すると、
\frac{\partial f}{\partial A_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}} \delta_{li} B_{jm} = \sum_{m} \frac{\partial f}{\partial C_{im}} B_{jm} = \left[ \frac{\partial f}{\partial C} B^T \right]_{ij}
となります。次に、$B$で微分してみますと、
\frac{\partial f}{\partial B_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}} \frac{\partial C_{lm}}{\partial B_{ij}}
より、
\frac{\partial C_{lm}}{\partial B_{ij}} = \frac{\partial \sum_{k} A_{lk} B_{km}}{\partial B_{ij}} = A_{li} \delta_{jm}
となりますから、
\frac{\partial f}{\partial B_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}} A_{li} \delta_{jm} = \sum_{l} \frac{\partial f}{\partial C_{lj}} A_{li} = \left[ A^T \frac{\partial f}{\partial C} \right]_{ij}
となります。
行列の和
$C(A,B) = A + B$の時、
\frac{\partial f}{\partial A_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}} \frac{\partial C_{lm}}{\partial A_{ij}}
なので、
\frac{\partial C_{lm}}{\partial A_{ij}} = \frac{\partial (A_{lm} + B_{lm})}{\partial A_{ij}} = \delta_{li} \delta_{mj}
となりますから、これを代入すると、
\frac{\partial f}{\partial A_{ij}} = \sum_{l,m} \frac{\partial f}{\partial C_{lm}} \delta_{li} \delta_{mj}=\frac{\partial f}{\partial C_{ij}} = \left[ \frac{\partial f}{\partial C} \right]_{ij}
となります。$B$に関しても同様で、
\frac{\partial f}{\partial B_{ij}} = \left[ \frac{\partial f}{\partial C} \right]_{ij}
です。
トレース
$C(A) = {\rm tr} A$の時、
\frac{\partial f}{\partial A_{ij}} = \frac{\partial f}{\partial C} \frac{\partial C}{\partial A_{ij}}
となりますが、
\frac{\partial C}{\partial A_{ij}} = \frac{\partial \sum_{k} A_{kk} }{\partial A_{ij}} = \delta_{ij}
ですので、
\frac{\partial f}{\partial A_{ij}} = \frac{\partial f}{\partial C} \delta_{ij}
となります。
rruleの実装
あとは、前の記事と同様にrruleを実装します。行列の積と和とトレースを定義します。
using ChainRulesCore
function ChainRulesCore.rrule(::typeof(+),a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
y = a + b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar
fbbar = ybar
return sbar,fabar,fbbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(*),a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
y = a * b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar*b'
fbbar = a'*ybar
return sbar,fabar,fbbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(tr),a::Field{Nc,L}) where {Nc,L}
y = tr(a)
function pullback(ybar)
sbar = NoTangent()
fbar = identity(Nc,L)*ybar
return sbar,fbar
end
return y, pullback
end
これでOKです。
自動微分
それでは、自動微分してみます。関数は${\rm tr}(A^2 + I)$とします。$I$は単位行列です。この関数は
function calc_f(x,d)
c = tr(x*x + d)
return c
end
で定義できます。
自動微分が正しいかを比較するために、数値微分も定義しておきます。Fieldの1箇所だけ微小変化させることで、数値微分を定義しておき、それを比較することにします。自動微分にはZygoteのgradient
を使うことにします。
コードは以下の通りです。
using Zygote
function test()
Nc = 2
L = 4
U0 = identity(Nc,L)
a = random_field(Nc,L)
b = random_field(Nc,L)
c = a*b
display(c)
d = a+b
display(d)
println(tr(c))
fa = calc_f(a,U0)
println(fa)
a_p = copy(a)
delta = 1e-9
a_p.A[1,1,1] += delta
fad = calc_f(a_p,U0)
df = (fad - fa)/delta
println(df)
ff = gradient(x -> calc_f(x,U0),a)
println(ff[1].A[1,1,1])
return
end
test()
これを実行すると、数値微分と自動微分がよく一致していることがわかると思います。
全体のコード
全体のコードを貼っておきます。
struct Field{Nc,L}
A::Array{Float64,3}
end
function Field(Nc,L)
A = zeros(Float64,Nc,Nc,L)
return Field{Nc,L}(A)
end
function identity(Nc,L)
A = zeros(Float64,Nc,Nc,L)
for i=1:L
for ic=1:Nc
A[ic,ic,i] = 1
end
end
return Field{Nc,L}(A)
end
function random_field(Nc,L)
A = rand(Float64,Nc,Nc,L)
return Field{Nc,L}(A)
end
function Base.copy(a::Field{Nc,L}) where {Nc,L}
A = zeros(Float64,Nc,Nc,L)
for i=1:L
for ic=1:Nc
for jc=1:Nc
A[jc,ic,i] = a.A[jc,ic,i]
end
end
end
return Field{Nc,L}(A)
end
function Base.zero(a::Field{Nc,L}) where {Nc,L}
return Field(Nc,L)
end
function Base.:+(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i=1:L
c.A[:,:,i] = view(a.A,:,:,i) .+ view(b.A,:,:,i)
end
return c
end
function Base.:-(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i=1:L
c.A[:,:,i] = view(a.A,:,:,i) .- view(b.A,:,:,i)
end
return c
end
function Base.:*(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i=1:L
mul!(view(c.A,:,:,i),view(a.A,:,:,i),view(b.A,:,:,i))
end
return c
end
function Base.:*(a::T,b::Field{Nc,L}) where {Nc,L,T <: Number}
c = zero(b)
for i=1:L
c.A[:,:,i] = a*view(b.A,:,:,i)
end
return c
end
function Base.:*(b::Field{Nc,L},a::T) where {Nc,L,T <: Number}
c = zero(b)
for i=1:L
c.A[:,:,i] = a*view(b.A,:,:,i)
end
return c
end
function Base.display(a::Field{Nc,L}) where {Nc,L}
for i=1:L
println("i = $i")
display(a.A[:,:,i])
println("\t")
end
end
function Base.adjoint(a::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i=1:L
for ic=1:Nc
for jc=1:Nc
c.A[ic,jc,i] = a.A[jc,ic,i]
end
end
end
return c
end
using LinearAlgebra
function LinearAlgebra.tr(a::Field{Nc,L}) where {Nc,L}
c = 0.0
for i=1:L
for ic=1:Nc
c += a.A[ic,ic,i]
end
end
return c
end
#=
using LinearAlgebra
function test()
Nc = 2
L = 4
U0 = identity(Nc,L)
a = random_field(Nc,L)
b = random_field(Nc,L)
c = a*b
display(c)
d = a+b
display(d)
println(tr(c))
return
end
test()
=#
using ChainRulesCore
function ChainRulesCore.rrule(::typeof(+),a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
y = a + b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar
fbbar = ybar
return sbar,fabar,fbbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(*),a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
y = a * b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar*b'
fbbar = a'*ybar
return sbar,fabar,fbbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(tr),a::Field{Nc,L}) where {Nc,L}
y = tr(a)
function pullback(ybar)
sbar = NoTangent()
fbar = identity(Nc,L)*ybar
return sbar,fbar
end
return y, pullback
end
function calc_f(x,d)
c = tr(x*x + d)
return c
end
using Zygote
function test()
Nc = 2
L = 4
U0 = identity(Nc,L)
a = random_field(Nc,L)
b = random_field(Nc,L)
c = a*b
display(c)
d = a+b
display(d)
println(tr(c))
fa = calc_f(a,U0)
println(fa)
a_p = copy(a)
delta = 1e-9
a_p.A[1,1,1] += delta
fad = calc_f(a_p,U0)
df = (fad - fa)/delta
println(df)
ff = gradient(x -> calc_f(x,U0),a)
println(ff[1].A[1,1,1])
return
end
test()