Juliaでの自動微分を使って行列で微分してみる
の続編です。さらに複雑な関数を微分できるようにしたいので、それをやってみたいと思います。
計算したい量
前回は、
f(\{ U \}) = \sum_{k=1}^{L}{\rm tr} (U_k U_k + I)
という量を計算し、その微分:
\frac{\partial f(\{ U \})}{\partial [U_{k}]_{ij}}
を自動微分で計算しました。ここで、$U_k$は$k$番目の格子点の上に定義された$n \times n$の行列で、コード上では
struct Field{Nc,L}
A::Array{Float64,3}
end
と定義されています。
次は、
f(\{ U \}) = \sum_{k=1}^{L}{\rm tr} (U_k U_{k+1} U_{k+2})
を計算し、その行列微分を自動で計算してみましょう。なお、$U_{L+1} = U_1$及び$U_{L+2} = U_2$とします。
シフトした量の定義
$U_{k+1}$という行列をどうやって扱うか、ですが、これはそれ専用の型を定義してしまいましょう。
struct Shifted_Field{Nc,L,shift}
F::Field{Nc,L}
function Shifted_Field(F::Field{Nc,L},shift) where {Nc,L}
shiftedF = apply_shift(F,shift)
return new{Nc,L,shift}(shiftedF)
end
end
function get_shiftedi(i,L,shift)
shiftedi = i + shift
while shiftedi < 1 || shiftedi > L
shiftedi += ifelse(shiftedi > L,-L,0)
shiftedi += ifelse(shiftedi < 1,L,0)
end
return shiftedi
end
function apply_shift(F::Field{Nc,L},shift) where {Nc,L}
Fout = zero(F)
for i=1:L
shiftedi = get_shiftedi(i,L,shift)
for ic=1:Nc
for jc=1:Nc
Fout.A[jc,ic,i] = F.A[jc,ic,shiftedi]
end
end
end
return F
end
function shift_Field(F::Field{Nc,L},shift) where {Nc,L}
return Shifted_Field(F,shift)
end
function shift_Field(a::Shifted_Field{Nc,L,shift0},shift) where {Nc,L,shift0}
return Shifted_Field(a.F,shift0 + shift)
end
これで、Fshift = shift_Field(F,shift)
とすることで、格子点がshiftだけずれたものを得ることができます。あとは、このShifted_Field
型同士の積や、Shifted_Field
型とField
型の積を定義:
function Base.:*(a::Shifted_Field{Nc,L,shift},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift}
c = zero(a.F)
for i=1:L
mul!(view(c.A,:,:,i),view(a.F.A,:,:,i),view(b.F.A,:,:,i))
end
return c
end
function Base.:*(a::T,b::Shifted_Field{Nc,L,shift}) where {Nc,L,T <: Number,shift}
c = zero(b.F)
for i=1:L
c.A[:,:,i] = a*view(b.F.A,:,:,i)
end
return c
end
function Base.:*(b::Shifted_Field{Nc,L,shift},a::T) where {Nc,L,T <: Number,shift}
c = zero(b.F)
for i=1:L
c.A[:,:,i] = a*view(b.F.A,:,:,i)
end
return c
end
function Base.:+(a::Shifted_Field{Nc,L,shift1},b::Shifted_Field{Nc,L,shift2}) where {Nc,L,shift1,shift2}
c = zero(a.F)
for i=1:L
c.A[:,:,i] = view(a.F.A,:,:,i) .+ view(b.F.A,:,:,i)
end
return c
end
function Base.:-(a::Shifted_Field{Nc,L,shift},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift}
c = zero(a.F)
for i=1:L
c.A[:,:,i] = view(a.F.A,:,:,i) .- view(b.F.A,:,:,i)
end
return c
end
function Base.:*(a::Shifted_Field{Nc,L,shift},b::Field{Nc,L}) where {Nc,L,shift}
c = zero(a.F)
for i=1:L
mul!(view(c.A,:,:,i),view(a.F.A,:,:,i),view(b.A,:,:,i))
end
return c
end
function Base.:*(a::Field{Nc,L},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift}
c = zero(a)
for i=1:L
mul!(view(c.A,:,:,i),view(a.A,:,:,i),view(b.F.A,:,:,i))
end
return c
end
function Base.:+(a::Shifted_Field{Nc,L,shift},b::Field{Nc,L}) where {Nc,L,shift}
c = zero(a.F)
for i=1:L
c.A[:,:,i] = view(a.F.A,:,:,i) .+ view(b.A,:,:,i)
end
return c
end
function Base.:+(a::Field{Nc,L},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift}
c = zero(a)
for i=1:L
c.A[:,:,i] = view(a.A,:,:,i) .+ view(b.F.A,:,:,i)
end
return c
end
function Base.:-(a::Shifted_Field{Nc,L,shift},b::Field{Nc,L}) where {Nc,L,shift}
c = zero(a.F)
for i=1:L
c.A[:,:,i] = view(a.F.A,:,:,i) .- view(b.A,:,:,i)
end
return c
end
function Base.:-(a::Field{Nc,L},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift}
c = zero(a)
for i=1:L
c.A[:,:,i] = view(a.A,:,:,i) .- view(b.F.A,:,:,i)
end
return c
end
function Base.adjoint(a::Shifted_Field{Nc,L}) where {Nc,L,shift}
c = zero(a.F)
for i=1:L
for ic=1:Nc
for jc=1:Nc
c.A[ic,jc,i] = a.F.A[jc,ic,i]
end
end
end
return c
end
しておけば、
function test2()
Nc = 2
L = 4
a = random_field(Nc,L)
ashift = shift_Field(a,1)
println(tr(a*ashift*a))
end
test2()
のようにすれば、
\sum_{k=1} {\rm tr} (U_k U_{k+1} U_k)
という量でも簡単に計算できます。ここで、Juliaでの自動微分を使って行列で微分してみる
で定義したものはすでに定義してあるとしています。
rruleの定義
連鎖律とpullback
自動微分をZygoteのやってもらうために、rruleを定義しておきます。まず、関数$f$の$[U_k]_{ij}$の微分の連鎖律
\frac{\partial f}{\partial [U_k]_{ij}} = \sum_{l,m,n} \frac{\partial f}{\partial [C_n]_{lm}} \frac{\partial [C_n]_{lm}}{\partial [U_k]_{ij}}
が書けることに注意します。pullbackを
\left[ {\cal B}_{C_{n}}(\bar{C}_{n}) \right]_{ij} \equiv \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} \frac{\partial [C_n]_{lm}}{\partial [U_k]_{ij}}
と定義すると、
\frac{\partial f}{\partial [U_k]_{ij}} = \sum_n \left[ {\cal B}_{C_{n}}(\bar{C}_{n}) \right]_{ij}
となります。
シフトした量の微分
$p$だけシフトしたもの$C_n({ U }) = U_{n+p}$の微分を考えます。これは、
\frac{\partial [C_n]_{lm}}{\partial [U_k]_{ij}} = \frac{\partial [U_{n+p}]_{lm}}{\partial [U_k]_{ij}} = \delta_{n,k-p} \delta_{li} \delta_{mj}
となり、pullbackは
\left[ {\cal B}_{C_{n}}(\bar{C}_{n}) \right]_{ij} = \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} \delta_{n,k-p} \delta_{li} \delta_{mj} =\delta_{n,k-p} \frac{\partial f}{\partial [C_{n}]_{ij}}
となります。ですので、
function ChainRulesCore.rrule(::typeof(shift_Field),a::Field{Nc,L},shift) where {Nc,L}
y = shift_Field(a,shift)
function pullback(ybar)
sbar = NoTangent()
fabar = shift_Field(ybar,-shift)
return sbar,fabar
end
return y, pullback
end
と実装してしまいましょう。
掛け算
$C_n({ U },{ P }) = U_n P_{n+p}$の場合、
\frac{\partial [C_n]_{lm}}{\partial [U_k]_{ij}} = \frac{\partial \sum_{a} [U_n]_{la} [P_{n+p}]_{am}}{\partial [U_k]_{ij}} = [P_{n+p}]_{jm} \delta_{n,k} \delta_{li}
\frac{\partial [C_n]_{lm}}{\partial [P_k]_{ij}} = [U_{n}]_{li} \delta_{n+p,k} \delta_{mj}
となりますから、
pullbackは$U$に関する微分と$P$に関する微分の二種類が現れ、
\left[ {\cal B}_{C_{n}}(\bar{C}_{n}) \right]_{ij} =
\left( \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} [P_{n+p}]_{jm} \delta_{n,k} \delta_{li} , \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} [U_{n}]_{li} \delta_{n+p,k} \delta_{mj} \right) \\
= \left( \sum_{m} \frac{\partial f}{\partial [C_n]_{im}} [P_{n+p}]_{jm} \delta_{n,k} ,\sum_{l} \frac{\partial f}{\partial [C_n]_{lj}} [U_{n}]_{li} \delta_{n+p,k} \right) \\
= \left( \delta_{n,k} \left[ \frac{\partial f}{\partial [C_n]} P_{n+p}^T \right]_{ij} ,\delta_{n,k-p} \left[ U_{n}^T \frac{\partial f}{\partial [C_n]}\right]_{ij} \right)
となります。
$C_n({ U },{ P }) = P_{n+p} U_n$の場合も同様に、
\frac{\partial [C_n]_{lm}}{\partial [U_k]_{ij}} = \frac{\partial \sum_{a} [P_{n+p}]_{la} [U_{n}]_{am}}{\partial [U_k]_{ij}} = [P_{n+p}]_{li} \delta_{n,k} \delta_{mj}
\frac{\partial [C_n]_{lm}}{\partial [P_k]_{ij}} = [U_{n}]_{jm} \delta_{n+p,k} \delta_{il}
となりますから、pullbackは
\left[ {\cal B}_{C_{n}}(\bar{C}_{n}) \right]_{ij} =
\left( \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} [P_{n+p}]_{li} \delta_{n,k}\delta_{mj} , \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} [U_{n}]_{jm} \delta_{n+p,k} \delta_{il} \right) \\
= \left( \sum_{l} \frac{\partial f}{\partial [C_n]_{lj}} [P_{n+p}]_{li} \delta_{n,k}\ ,\sum_{m} \frac{\partial f}{\partial [C_n]_{im}} [U_{n}]_{jm} \delta_{n+p,k} \right) \\
= \left( \delta_{n,k} \left[ P_{n+p}^T \frac{\partial f}{\partial [C_n]} \right]_{ij} ,\delta_{n,k-p} \left[ \frac{\partial f}{\partial [C_n]} U_{n}^T \right]_{ij} \right)
となります。
よって、それぞれの関数は
function ChainRulesCore.rrule(::typeof(*),a::Field{Nc,L},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift}
y = a * b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar*b'
f2 = a'*ybar
fbbar = shift_Field(f2,-shift)
return sbar,fabar,fbbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(*),a::Shifted_Field{Nc,L,shift},b::Field{Nc,L}) where {Nc,L,shift}
y = a * b
function pullback(ybar)
sbar = NoTangent()
fabar = a'*ybar
f2 = ybar*b'
fbbar = shift_Field(f2,-shift)
return sbar,fabar,fbbar
end
return y, pullback
end
となります。
自動微分の実行
あとは自動微分を試します。数式を関数にすると
function calc_tr(x)
xshift_1 = shift_Field(x,1)
xshift_2 = shift_Field(x,2)
xshift_3 = shift_Field(x,3)
return tr(x*xshift_1*xshift_2*xshift_3)
end
となります。これを自動微分してみて、数値微分と比較してみます。
function test()
Nc = 2
L = 4
a = random_field(Nc,L)
b = random_field(Nc,L)
c = a*b
display(c)
d = a+b
display(d)
println(tr(c))
a_p = copy(a)
delta = 1e-9
a_p.A[1,1,1] += delta
fa = calc_tr(a)
fad = calc_tr(a_p)
df = (fad - fa)/delta
println(df)
ff = gradient(x -> calc_tr(x),a)
println(ff[1].A[1,1,1])
return
end
test()
ちゃんと微分できていることがわかると思います。