Juliaでの自動微分を使って行列を引数にした複雑な関数を行列で微分してみる
の続編です。前の記事では行列で微分してみましたが、今度は複素数の微分であるウィルティンガー微分をしてみます。
注:完全版がこちらにあります。
ウィルティンガー微分とは
定義
複素数$z$を二つの実数$x,y$で表すと$z = x+iy$と書けますが、この複素数$z$の適当な領域で積分可能な関数$f(z) = u + iv$を考えます(u,vは実数)。この時、偏微分
\frac{\partial f}{\partial x} = \frac{\partial u}{\partial x} + i \frac{\partial v}{\partial x} \\
\frac{\partial f}{\partial y} = \frac{\partial u}{\partial y} + i \frac{\partial v}{\partial y}
を考えることができます。関数$f$の全微分は
df = \frac{\partial f}{\partial x} dx + \frac{\partial f}{\partial y} dy
と書くとします。ここで、$z = x + iy$, $\bar{z} = x - iy$を導入すると、
dz = dx + i dy \\
d\bar{z} = dx - i dy
から、
dx = (dz + d\bar{z})/2 \\
dy = (dz - d\bar{z})/2i
が得られます。これを用いて全微分を書き直すと、
df = \frac{\partial f}{\partial x} (dz + d\bar{z})/2 + \frac{\partial f}{\partial y} (dz - d\bar{z})/2i \\
= \frac{\partial f}{\partial z} dz + \frac{\partial f}{\partial \bar{z}} d\bar{z}
となります。ここで、
\frac{\partial f}{\partial z} = \frac{1}{2} \left(\frac{\partial f}{\partial x} - i \frac{\partial f}{\partial y} \right) \\
\frac{\partial f}{\partial \bar{z}} = \frac{1}{2} \left(\frac{\partial f}{\partial x} + i \frac{\partial f}{\partial y} \right)
がウィルティンガー微分です。
性質
ウィルティンガー微分は複素数関数の微分の一種ですが、この微分は色々と便利な性質があります。例えば、線形性:
\frac{\partial}{\partial z} (\alpha f + \beta g) = \alpha \frac{\partial f}{\partial z} + \beta \frac{\partial g}{\partial z} \\
\frac{\partial}{\partial \bar{z}} (\alpha f + \beta g) = \alpha \frac{\partial f}{\partial \bar{z}} + \beta \frac{\partial g}{\partial \bar{z}}
積の微分:
\frac{\partial}{\partial z} (f g) = \frac{\partial f}{\partial z} g + f \frac{\partial g}{\partial z} \\
\frac{\partial}{\partial \bar{z}} (f g) = \frac{\partial f}{\partial \bar{z}} g + f \frac{\partial g}{\partial \bar{z}}
連鎖律:
\frac{\partial}{\partial z} f(g(z)) = \frac{\partial f}{\partial g} \frac{\partial g}{\partial z} + \frac{\partial f}{\partial \bar{g}} \frac{\partial \bar{g}}{\partial z} \\
\frac{\partial}{\partial \bar{z}} f(g(z)) = \frac{\partial f}{\partial g} \frac{\partial g}{\partial \bar{z}} + \frac{\partial f}{\partial \bar{g}} \frac{\partial \bar{g}}{\partial \bar{z}}
が成り立ちます。これらが成り立つということは、複素数関数の微分を通常の微分のように扱える、ということになります。
例えば、
f(z) = z^4 + 2 z \bar{z} + z
という関数があるとします。この時、ウィルティンガー微分は、$z$と$\bar{z}$を独立変数と思って微分することで、
\frac{\partial f}{\partial z} = 4 z^3 + 2 \bar{z} + 1 \\
\frac{\partial f}{\partial \bar{z}} = 2 z
と計算できます。
このように通常の微分のように扱えるということは、自動微分もきっとできるはずです。
自動微分
微分したい関数
さて、ウィルティンガー微分をJuliaの自動微分で実装してみましょう。微分したい関数を
y(z) = {\rm re} (f(z))
とします。出力を実数とすると、Zygote.jlにあるgradient
で計算することができます。ですので、Zygoteのgradient
で微分を計算することを目的とします。
型の定義
まず初めに、通常の微分と混ざって混乱することを防ぐために、独自型を定義しておきます。
abstract type CField end
struct ComplexField <: CField
value::ComplexF64
end
struct Adjoint_ComplexField{T} <: CField
parent::T
Adjoint_ComplexField(a) = new{typeof(a)}(a)
end
function Base.adjoint(a::ComplexField)
return Adjoint_ComplexField(a)
end
function Base.adjoint(a::Adjoint_ComplexField)
return a.parent
end
function Base.display(a::ComplexField)
display(a.value)
end
function Base.display(a::Adjoint_ComplexField)
display(a.parent.value')
end
ここで、$z$と$\bar{z}$に相当する二つの型を導入しました。また、adjoint
というのは、Juliaではa'
の'
を実行した時に呼び出される関数でして、互いに複素共役の関係ですから移り変われるようにしておきます。
次に、演算を定義します。とりあえず掛け算を定義しておくと、
function Base.:*(a::ComplexField,b::ComplexField)
return ComplexField(a.value*b.value)
end
function Base.:*(a::Adjoint_ComplexField,b::Adjoint_ComplexField)
return ComplexField(a.parent.value'*b.parent.value')
end
function Base.:*(a::Adjoint_ComplexField,b::ComplexField)
return ComplexField(a.parent.value'*b.value)
end
function Base.:*(a::ComplexField,b::Adjoint_ComplexField)
return ComplexField(a.value*b.parent.value')
end
function Base.:*(a::ComplexField,b::T) where T <: Number
return ComplexField(a.value*b)
end
function Base.:*(a::T,b::ComplexField) where T <: Number
return ComplexField(a*b.value)
end
function Base.:*(a::Adjoint_ComplexField,b::T) where T <: Number
return ComplexField(a.parent.value'*b)
end
function Base.:*(a::T,b::Adjoint_ComplexField) where T <: Number
return ComplexField(a*b.parent.value')
end
となります。これらをテストするのであれば、
function test()
a = ComplexField(2+3im)
display(a)
b = a'
display(b)
c = 2*a*b
display(c)
end
test()
などとすると良いでしょう。
数値微分
これから作る自動微分の結果を確かめるために、数値微分の関数を作っておきます。ウィルティンガー微分の定義に従って、
function numerical_derivative(f,x::ComplexField)
delta = 1e-6
xd = ComplexField(x.value + delta)
fx = f(x)
fxd = f(xd)
fg_n = (fxd-fx)/delta
xd_im = ComplexField(x.value + im*delta)
fxd_im = f(xd_im)
fg_n_im = (fxd_im-fx)/delta
return (fg_n - im*fg_n_im)/2
end
としました。これは、実数部分を微小に動かして$\partial/\partial x$を作り、虚数部分を動かして$\partial /\partial y$を作り、定義通りにウィルティンガー微分を計算したものです。
実数での微分との違い
さて、ウィルティンガー微分は実数での微分と同じような形ですが、違う部分もあります。違いは、複素共役の数$\bar{z}$を独立に扱う必要がある、ということです。これは、実質2変数関数の微分のようになっていることを意味しています。例えば、実数部分をとる場合は、
{\rm Re} (z) = \frac{1}{2}(z + \bar{z})
となりますから、この関数は$z$で微分しても$\bar{z}$で微分しても有限の値が出ます。そして、
\frac{\partial }{\partial z}{\rm Re} (z) =\frac{1}{2} \\
\frac{\partial }{\partial \bar{z}}{\rm Re} (z) =\frac{1}{2}
となりますから、通常の実数$x$を微分
\frac{\partial }{\partial x}{\rm Re} (x) = 1
とは値が異なっています。これは、$z$と$\bar{z}$を独立に扱っているためです。そのため、関数${\rm Re}(z)$を
function real_c(a::ComplexField)
return real(a.value)
end
function real_c(a::Adjoint_ComplexField)
return real(a.parent.value)
end
と定義して、Zygoteで微分
using Zygote
function test()
a = ComplexField(2+3im)
display(a)
b = a'
display(b)
c = 2*a*b
display(c)
f(x) = real_c(x)
g = gradient(f,a)
println("Autograd: ", g)
gnu = numerical_derivative(f,a)
println("Numerical grad: ", gnu)
end
すると、
2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: ((value = 1.0,),)
Numerical grad: 0.500000000069889 - 0.0im
となり、数値微分は正しいですが、自動微分はウィルティンガー微分の値1/2になってくれません。それではどうすれば良いでしょうか。
自動微分のルールの実装
実部を取る関数を微分したときに1/2になるようにするためには、この関数の微分を自分で定義してしまえばよいでしょう。つまり、ChainRulesCoreを用いて、
using ChainRulesCore
function ChainRulesCore.rrule(::typeof(real_c),a::T1) where {T1 <: CField}
y = real_c(a)
function pullback(ybar)
sbar = NoTangent()
fbar = ybar/2
return sbar,fbar
end
return y, pullback
end
と実装すると、real_c
という関数を微分したときには自動的にこの関数が呼ばれて1/2
がかえるようになります。
2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: (0.5,)
Numerical grad: 0.500000000069889 - 0.0im
しかし、これでは十分ではありません。次は、x*x'
の自動微分をしてみることにします。
function test()
a = ComplexField(2+3im)
display(a)
b = a'
display(b)
c = 2*a*b
display(c)
f(x) = real_c(x*x')
g = gradient(f,a)
println("Autograd: ", g)
gnu = numerical_derivative(f,a)
println("Numerical grad: ", gnu)
end
まず、これはエラーが出ます。
ERROR: LoadError: Need an adjoint for constructor ComplexField. Gradient is of type Float64
これは、ComplexField
の微分が定義されていない、というエラーです。ですので、
function ChainRulesCore.rrule(::typeof(ComplexField),a::Number)
y = ComplexField(a)
function pullback(ybar)
sbar = NoTangent()
fbar = ybar
return sbar,fbar
end
return y, pullback
end
と、微分の定義を追加します。その結果、
2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: ((value = 2.0 + 3.0im,),)
Numerical grad: 2.0000005003240062 - 3.000000500463784im
となります。間違ってますね!!!
間違いの修正
上では、コードではちゃんと自動微分ができているように見えますが、答えは間違っていました。何が起きたのでしょうか?その理由は、ウィルティンガー微分における複素共役の取り扱い方法にあります。ウィルティンガー微分では、複素数$z$の複素共役$\bar{z}$は、「独立変数」です。ですので、独立に扱って微分をしなければなりません。逆に考えれば、計算の途中で、$z$から複素共役を取ったりその逆をしたりを勝手にしてはいけません。最後まで独立変数だと思って取り扱う必要があります。
独立変数にする一番単純な方法は、関数を2変数にすることです。つまり、独立2変数の関数の自動微分を実装すれば、ちゃんとできるようになるはずです。
ということで、微分を計算したい関数を
function calc_f(x)
xdag = x'
a = x*xdag
adag = xdag*x
return real_c(a,adag)
end
とします。実部をとる関数が2変数関数になっています。
2変数関数の定義
実部をとる関数real_c
を2変数関数:
function real_c(a::ComplexField,b::ComplexField)
return real(a.value + b.value)/2
end
とその微分:
function ChainRulesCore.rrule(::typeof(real_c),a::T1,b::T2) where {T1 <: CField,T2 <: CField}
y = real_c(a,b)
function pullback(ybar)
sbar = NoTangent()
fbar = ybar/2
return sbar,fbar,fbar
end
return y, pullback
end
を定義します。これで、
using Zygote
function test()
a = ComplexField(2+3im)
display(a)
b = a'
display(b)
c = 2*a*b
display(c)
f(x) = calc_f(x)
g = gradient(f,a)
println("Autograd: ", g)
gnu = numerical_derivative(f,a)
println("Numerical grad: ", gnu)
end
test()
を実行すると、
2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: ((value = 4.0 + 6.0im,),)
Numerical grad: 2.0000005003240062 - 3.000000500463784im
となります。間違ってますね!!!???
これを回避するためには、今回計算する自動微分に関わる演算をできるだけ自前で定義します。例えば、
function ChainRulesCore.rrule(::typeof(*),a::CField,b::CField)
y = a * b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar*b
fbbar = a*ybar
return sbar,fabar,fbbar
end
return y, pullback
end
と掛け算を定義します。これで、掛け算に関する場合にこちらの関数が呼ばれるようになります。また、これでは
ERROR: LoadError: MethodError: no method matching +(::ComplexField, ::ComplexField)
と怒られますので、足し算も定義しておきます。
function Base.:+(a::ComplexField,b::ComplexField)
return ComplexField(a.value+b.value)
end
function Base.:+(a::Adjoint_ComplexField,b::Adjoint_ComplexField)
return ComplexField(a.parent.value'+b.parent.value')
end
function Base.:+(a::Adjoint_ComplexField,b::ComplexField)
return ComplexField(a.parent.value'+b.value)
end
function Base.:+(a::ComplexField,b::Adjoint_ComplexField)
return ComplexField(a.value+b.parent.value')
end
function Base.:+(a::Adjoint_ComplexField,b::Number)
return ComplexField(a.parent.value'+b)
end
function Base.:+(a::ComplexField,b::Number)
return ComplexField(a.value+b)
end
function Base.:+(a::Number,b::Adjoint_ComplexField)
return b + a
end
function Base.:+(a::Number,b::ComplexField)
return b+ a
end
さらに実行すると、
ERROR: LoadError: Need an adjoint for constructor Adjoint_ComplexField{ComplexField}. Gradient is of type ComplexField
と怒られます。これは、Adjoint_ComplexField
の微分が定義されていないというエラーですので、
function ChainRulesCore.rrule(::typeof(Adjoint_ComplexField),a::ComplexField)
y = a'
function pullback(ybar)
sbar = NoTangent()
fbar = ZeroTangent()
return sbar,fbar
end
return y, pullback
end
と複素共役の微分を定義しておきます。複素共役は独立変数ですから、この微分はゼロになります($\partial \bar{z}/\partial z = 0$)。
これでやっと、
2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: (ComplexField(2.0 - 3.0im),)
Numerical grad: 2.0000005003240062 - 3.000000500463784im
と正しい微分になりました。
ついでなのでもう少し複雑な微分をやってみます。
function calc_f2(x)
xdag = x'
a = x*xdag*x + x
adag = xdag*x*xdag + xdag
return real_c(a,adag)
end
これは、
f(z) = {\rm Re} (z z' z + z)
です。足し算が出てきましたので、足し算の自動微分を定義しておきます。
function ChainRulesCore.rrule(::typeof(+),a::CField,b::CField)
y = a + b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar
fbbar = ybar
return sbar,fabar,fbbar
end
return y, pullback
end
その結果、
using Zygote
function test()
a = ComplexField(2+3im)
display(a)
b = a'
display(b)
c = 2*a*b
display(c)
#f(x) = calc_f(x)
f(x) = calc_f2(x)
g = gradient(f,a)
println("Autograd: ", g)
gnu = numerical_derivative(f,a)
println("Numerical grad: ", gnu)
end
test()
を実行すると、
2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: (ComplexField(11.0 - 6.0im),)
Numerical grad: 11.000003002692438 - 6.000001000927568im
となり、ちゃんと自動微分できています。
全体コード
全体のコードは以下の通りです。
abstract type CField end
struct ComplexField <: CField
value::ComplexF64
end
struct Adjoint_ComplexField{T} <: CField
parent::T
Adjoint_ComplexField(a) = new{typeof(a)}(a)
end
function Base.adjoint(a::ComplexField)
return Adjoint_ComplexField(a)
end
function Base.adjoint(a::Adjoint_ComplexField)
return a.parent
end
function Base.display(a::ComplexField)
display(a.value)
end
function Base.display(a::Adjoint_ComplexField)
display(a.parent.value')
end
function Base.:*(a::ComplexField,b::ComplexField)
return ComplexField(a.value*b.value)
end
function Base.:*(a::Adjoint_ComplexField,b::Adjoint_ComplexField)
return ComplexField(a.parent.value'*b.parent.value')
end
function Base.:*(a::Adjoint_ComplexField,b::ComplexField)
return ComplexField(a.parent.value'*b.value)
end
function Base.:*(a::ComplexField,b::Adjoint_ComplexField)
return ComplexField(a.value*b.parent.value')
end
function Base.:*(a::ComplexField,b::T) where T <: Number
return ComplexField(a.value*b)
end
function Base.:*(a::T,b::ComplexField) where T <: Number
return ComplexField(a*b.value)
end
function Base.:*(a::Adjoint_ComplexField,b::T) where T <: Number
return ComplexField(a.parent.value'*b)
end
function Base.:*(a::T,b::Adjoint_ComplexField) where T <: Number
return ComplexField(a*b.parent.value')
end
function Base.:+(a::ComplexField,b::ComplexField)
return ComplexField(a.value+b.value)
end
function Base.:+(a::Adjoint_ComplexField,b::Adjoint_ComplexField)
return ComplexField(a.parent.value'+b.parent.value')
end
function Base.:+(a::Adjoint_ComplexField,b::ComplexField)
return ComplexField(a.parent.value'+b.value)
end
function Base.:+(a::ComplexField,b::Adjoint_ComplexField)
return ComplexField(a.value+b.parent.value')
end
function Base.:+(a::Adjoint_ComplexField,b::Number)
return ComplexField(a.parent.value'+b)
end
function Base.:+(a::ComplexField,b::Number)
return ComplexField(a.value+b)
end
function Base.:+(a::Number,b::Adjoint_ComplexField)
return b + a
end
function Base.:+(a::Number,b::ComplexField)
return b+ a
end
function numerical_derivative(f,x::ComplexField)
delta = 1e-6
xd = ComplexField(x.value + delta)
fx = f(x)
fxd = f(xd)
fg_n = (fxd-fx)/delta
xd_im = ComplexField(x.value + im*delta)
fxd_im = f(xd_im)
fg_n_im = (fxd_im-fx)/delta
return (fg_n - im*fg_n_im)/2
end
function real_c(a::ComplexField)
return real(a.value)
end
function real_c(a::Adjoint_ComplexField)
return real(a.parent.value)
end
#追加
using ChainRulesCore
function ChainRulesCore.rrule(::typeof(real_c),a::T1) where {T1 <: CField}
y = real_c(a)
function pullback(ybar)
sbar = NoTangent()
fbar = ybar/2
return sbar,fbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(Adjoint_ComplexField),a::ComplexField)
y = a'
function pullback(ybar)
sbar = NoTangent()
fbar = ZeroTangent()
return sbar,fbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(ComplexField),a::Number)
y = ComplexField(a)
function pullback(ybar)
sbar = NoTangent()
fbar = ybar
return sbar,fbar
end
return y, pullback
end
function real_c(a::ComplexField,b::ComplexField)
return real(a.value + b.value)/2
end
function ChainRulesCore.rrule(::typeof(real_c),a::T1,b::T2) where {T1 <: CField,T2 <: CField}
y = real_c(a,b)
function pullback(ybar)
sbar = NoTangent()
fbar = ybar/2
return sbar,fbar,fbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(*),a::CField,b::CField)
y = a * b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar*b
fbbar = a*ybar
return sbar,fabar,fbbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(+),a::CField,b::CField)
y = a + b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar
fbbar = ybar
return sbar,fabar,fbbar
end
return y, pullback
end
function calc_f(x)
xdag = x'
a = x*xdag
adag = xdag*x
return real_c(a,adag)
end
function calc_f2(x)
xdag = x'
a = x*xdag*x + x
adag = xdag*x*xdag + xdag
return real_c(a,adag)
end
using Zygote
function test()
a = ComplexField(2+3im)
display(a)
b = a'
display(b)
c = 2*a*b
display(c)
#f(x) = calc_f(x)
f(x) = calc_f2(x)
g = gradient(f,a)
println("Autograd: ", g)
gnu = numerical_derivative(f,a)
println("Numerical grad: ", gnu)
end
test()