Juliaでの自動微分を使って、ウィルティンガー微分してみるは個人的にまだ不満点がありましたので、それを改訂することにしました。
ウィルティンガー微分について、まず前の記事を再掲し、Juliaによる完全な自動微分を目指します。
ウィルティンガー微分とは
定義
複素数$z$を二つの実数$x,y$で表すと$z = x+iy$と書けますが、この複素数$z$の適当な領域で積分可能な関数$f(z) = u + iv$を考えます(u,vは実数)。この時、偏微分
\displaylines{
\frac{\partial f}{\partial x} = \frac{\partial u}{\partial x} + i \frac{\partial v}{\partial x} \\
\frac{\partial f}{\partial y} = \frac{\partial u}{\partial y} + i \frac{\partial v}{\partial y}
}
を考えることができます。関数$f$の全微分は
df = \frac{\partial f}{\partial x} dx + \frac{\partial f}{\partial y} dy
と書くとします。ここで、$z = x + iy$, $\bar{z} = x - iy$を導入すると、
\displaylines{
dz = dx + i dy \\
d\bar{z} = dx - i dy
}
から、
\displaylines{
dx = (dz + d\bar{z})/2 \\
dy = (dz - d\bar{z})/2i
}
が得られます。これを用いて全微分を書き直すと、
\displaylines{
df = \frac{\partial f}{\partial x} (dz + d\bar{z})/2 + \frac{\partial f}{\partial y} (dz - d\bar{z})/2i \\
= \frac{\partial f}{\partial z} dz + \frac{\partial f}{\partial \bar{z}} d\bar{z}
}
となります。ここで、
\displaylines{
\frac{\partial f}{\partial z} = \frac{1}{2} \left(\frac{\partial f}{\partial x} - i \frac{\partial f}{\partial y} \right) \\
\frac{\partial f}{\partial \bar{z}} = \frac{1}{2} \left(\frac{\partial f}{\partial x} + i \frac{\partial f}{\partial y} \right)
}
がウィルティンガー微分です。
性質
ウィルティンガー微分は複素数関数の微分の一種ですが、この微分は色々と便利な性質があります。例えば、線形性:
\displaylines{
\frac{\partial}{\partial z} (\alpha f + \beta g) = \alpha \frac{\partial f}{\partial z} + \beta \frac{\partial g}{\partial z} \\
\frac{\partial}{\partial \bar{z}} (\alpha f + \beta g) = \alpha \frac{\partial f}{\partial \bar{z}} + \beta \frac{\partial g}{\partial \bar{z}}
}
積の微分:
\displaylines{
\frac{\partial}{\partial z} (f g) = \frac{\partial f}{\partial z} g + f \frac{\partial g}{\partial z} \\
\frac{\partial}{\partial \bar{z}} (f g) = \frac{\partial f}{\partial \bar{z}} g + f \frac{\partial g}{\partial \bar{z}}
}
連鎖律:
\displaylines{
\frac{\partial}{\partial z} f(g(z)) = \frac{\partial f}{\partial g} \frac{\partial g}{\partial z} + \frac{\partial f}{\partial \bar{g}} \frac{\partial \bar{g}}{\partial z} \\
\frac{\partial}{\partial \bar{z}} f(g(z)) = \frac{\partial f}{\partial g} \frac{\partial g}{\partial \bar{z}} + \frac{\partial f}{\partial \bar{g}} \frac{\partial \bar{g}}{\partial \bar{z}}
}
が成り立ちます。これらが成り立つということは、複素数関数の微分を通常の微分のように扱える、ということになります。
例えば、
\displaylines{
f(z) = z^4 + 2 z \bar{z} + z
}
という関数があるとします。この時、ウィルティンガー微分は、$z$と$\bar{z}$を独立変数と思って微分することで、
\displaylines{
\frac{\partial f}{\partial z} = 4 z^3 + 2 \bar{z} + 1 \\
\frac{\partial f}{\partial \bar{z}} = 2 z
}
と計算できます。
このように通常の微分のように扱えるということは、自動微分もきっとできるはずです。
実数での微分との違い
さて、ウィルティンガー微分は実数での微分と同じような形ですが、違う部分もあります。違いは、複素共役の数$\bar{z}$を独立に扱う必要がある、ということです。これは、実質2変数関数の微分のようになっていることを意味しています。例えば、実数部分をとる場合は、
{\rm Re} (z) = \frac{1}{2}(z + \bar{z})
となりますから、この関数は$z$で微分しても$\bar{z}$で微分しても有限の値が出ます。そして、
\displaylines{
\frac{\partial }{\partial z}{\rm Re} (z) =\frac{1}{2},\: \\
\frac{\partial }{\partial \bar{z}}{\rm Re} (z) =\frac{1}{2}
}
となりますから、通常の実数$x$を微分
\frac{\partial }{\partial x}{\rm Re} (x) = 1
とは値が異なっています。
前回の記事の問題点
以前の自動微分での不満点は、関数の実部の計算をする際に二変数関数を用意していました。というのは、ウィルティンガー微分では$z$と$\bar{z}$は独立だからです。しかし、コーディングをする際にわざわざ二変数関数を考えるのは面倒です。もう少し上手くできるはず、と思いました。
Juliaでのコーディング
なるべくシンプルに実装したいため、以前とstructを変更します。まず、複素変数としては
struct ComplexField{T}
z::T
end
とします。以前はAdjoint_ComplexField{T}
によって$\bar{z}$を定義していました。しかし、今回は使いません。このComplexField{T}
型に対する演算を
ComplexField(a::T) where T = ComplexField{T}(a)
function Base.adjoint(a::ComplexField{T}) where T
return ComplexField{T}(a.z')
end
function Base.:*(a::ComplexField,b::ComplexField)
return ComplexField(a.z*b.z)
end
function Base.:*(a::T,b::ComplexField) where T<:Number
return ComplexField(a*b.z)
end
function Base.:*(b::ComplexField,a::T) where T<:Number
return ComplexField(a*b.z)
end
function Base.:+(a::ComplexField,b::ComplexField)
return ComplexField(a.z+b.z)
end
と定義します。以前と違ってComplexField{T}
しかありませんので、組み合わせが減り、定義するメソッドが減りました。
次に、実部をとる関数を
function Base.real(a::ComplexField)
ar = (a.z+a.z')/2
return real(ar)
end
と定義します。
また、結果の正しさを確認するため、ウィルティンガー微分の定義通りに数値微分する関数を
function numerical_derivative(f,x::ComplexField)
delta = 1e-6
xd = ComplexField(x.z + delta)
fx = f(x)
fxd = f(xd)
fg_n = (fxd-fx)/delta
xd_im = ComplexField(x.z + im*delta)
fxd_im = f(xd_im)
fg_n_im = (fxd_im-fx)/delta
return (fg_n - im*fg_n_im)/2
end
と用意しておきます。
rruleの実装
realの実装
まず、上の関数たちが定義された状態で
f(z) = {\rm Re}(z)
を自動微分することを考えます。
using Zygote
function main()
a = ComplexField(2+3im)
b = ComplexField(4+im)
c = ComplexField(1+2im)
println(a*b)
println(a*b')
f(x) = real(x)
gnu = numerical_derivative(f,a)
println("Numerical grad: ", gnu)
g = gradient(f,a)[1]
println("Autograd: ", g)
end
main()
とすると、
ComplexField{Complex{Int64}}(5 + 14im)
ComplexField{Complex{Int64}}(11 + 10im)
Numerical grad: 0.500000000069889 - 0.0im
Autograd: (z = 1.0 + 0.0im,)
となり、我々が定義した${\rm Re}(x)$の微分とは異なってしまっています。
さて、$z$と$\bar{z}$の二変数を扱っていると考えると、これは$\vec{x} = (z,\bar{z})^T$という二変数のベクトルを扱っていると考えてもよいでしょう。通常のベクトルであれば、ある実スカラー関数をそのベクトルで「微分」すると
\frac{\partial f}{\partial \vec{x}} = (\frac{\partial f}{\partial x_1} ,\frac{\partial f}{\partial x_2})^T
というグラディエントを計算することと同じになっています。そこで、$z$と$\bar{z}$をまとめたベクトル$\vec{x}$を考えると、
\frac{\partial f}{\partial \vec{x}} = (\frac{\partial f}{\partial z} ,\frac{\partial f}{\partial \bar{z}})^T
としてもよいでしょう。
そして、ある実スカラー関数$f$が
f(\vec{x}) = f(\vec{g}(\vec{x}))
のようにベクトル$\vec{x}$を引数とするベクトル$\vec{g}(\vec{x})$の関数だとすると、微分の連鎖律は
\left[ \frac{\partial f}{\partial \vec{x}} \right]_i = \sum_{k} \frac{\partial f}{\partial g_k} \frac{\partial g_k}{\partial x_i}
となりますから、ある実スカラー関数$f$が
f(z,{\bar z}) = f(g(z,{\bar z}),\bar{g}(z,{\bar z}))
のように複素数の関数であれば、同様に
\displaylines{
\frac{\partial f}{\partial z} = \frac{\partial f}{\partial g} \frac{\partial g}{\partial z} + \frac{\partial f}{\partial \bar{g}} \frac{\partial \bar{g}}{\partial z}\\
\frac{\partial f}{\partial \bar{z}} =\frac{\partial f}{\partial g} \frac{\partial g}{\partial \bar{z}}
+\frac{\partial f}{\partial \bar{g}} \frac{\partial \bar{g}}{\partial \bar{z}}
}
となるでしょう。これはすでにウィルティンガー微分の導入のところでも述べましたが、実はこれが非常に重要です。先程のコードでは
Autograd: (z = 1.0 + 0.0im,)
という出力が返ってきていましたが、これは値も間違っていましたが、型はComplexField
型でありません。しかし、定義どおりに考えると、実スカラー関数の複素数$z$微分は、ComplexField
型であるべきです。
そこで、rruleを
using ChainRulesCore
function ChainRulesCore.rrule(::typeof(real),a::T1) where T1 <: ComplexField
y = real(a)
function pullback(ybar)
sbar = NoTangent()
fbar = ComplexField(ybar/2+0im)
return sbar,fbar
end
return y, pullback
end
とします。ある関数$f(x)$のpullbackは
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial f} \frac{\partial f}{\partial x} \equiv B_{f}(\frac{\partial L}{\partial f},x)
です。ここで、
\bar{f} \equiv \frac{\partial L}{\partial f}
とすると、
\displaylines{
\frac{\partial L}{\partial x} = \bar{f}\frac{\partial f}{\partial x} \equiv B_{f}(\bar{f},x)
}
です。もし$x$も$f(x)$がベクトルの場合には、
\frac{\partial L}{\partial x_i} = \sum_j\bar{f_j } \frac{\partial f_j}{\partial x_i}
となります。ここで、$L$は常に実スカラー関数としています。次に、$x$も$f(x)$も複素数であれば、
\frac{\partial L}{\partial z} = \bar{f} \frac{\partial f}{\partial z} +\bar{\bar{f}} \frac{\partial \bar{f}}{\partial z}
となります。
関数real
では、$f(z,\bar{z}) = (z + \bar{z})/2 $で実スカラーですから、インプットの$x$が複素数(二変数)、$f$を1変数と思えばいいので、
\frac{\partial L}{\partial z} = \bar{{\rm Re}} \frac{\partial {\rm Re}(z,\bar{z})}{\partial z} = \bar{{\rm Re}}/2
となります。
これで、
Numerical grad: 0.500000000069889 - 0.0im
Autograd: ComplexField{ComplexF64}(0.5 + 0.0im)
となり、良さそうな結果になります。
掛け算のrruleの実装
次に、
{\rm Re}(z z)
の$z$微分を計算してみます。
まず、
function main()
a = ComplexField(2+3im)
b = ComplexField(4+im)
c = ComplexField(1+2im)
println(a*b)
println(a*b')
f2(x) = real(x*x)
gnu = numerical_derivative(f2,a)
println("Numerical grad: ", gnu)
g = gradient(f2,a)[1]
println("Autograd: ", g)
end
はエラーになりますので、rruleを実装します。rruleは
function ChainRulesCore.rrule(::typeof(*),a::T1,b::T1) where T1 <: ComplexField
y = a * b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar*b
fbbar = a*ybar
return sbar,fabar,fbbar
end
return y, pullback
end
です。これは比較的そのままかと思います。これで、
Numerical grad: 2.0000005003240062 + 3.000000500463784im
Autograd: ComplexField{ComplexF64}(2.0 + 3.0im)
となります。
複素共役の自動微分の実装
次に、
{\rm Re}(\bar{z})
の自動微分を実装します。以前の記事ではこの微分を計算するために二変数関数を用意しました。今回は式の通り、一変数関数のまま実装します。
まず、このままf(x) = real(x')
としてもエラーが出て計算できません。そこで、rruleを実装します。Juliaでは、複素共役の記号x'
は関数Base.adjoint(x)
を意味していますので、このadjoint
に対するrruleを追加します。
pullbackを定義通りに計算します。定義に従って計算すると、
\displaylines{
\frac{\partial L}{\partial z} =\frac{\partial L}{\partial \bar{z}} \frac{\partial \bar{z}}{\partial z} +\frac{\partial L}{\partial \bar{\bar{z}}} \frac{\partial \bar{\bar{z}}}{\partial z} \\
= \frac{\partial L}{\partial \bar{z}} \frac{\partial \bar{z}}{\partial z} +\frac{\partial L}{\partial z} \frac{\partial z}{\partial z} \\
= \frac{\partial L}{\partial \bar{z}} \times 0+\frac{\partial L}{\partial z} \times 1 \\
= (\frac{\partial L}{\partial \bar{z}})' \times 1
}
となります。よって、rruleは
function ChainRulesCore.rrule(::typeof(Base.adjoint),a::T1) where T1 <: ComplexField
y =a'
function pullback(ybar)
sbar = NoTangent()
fbar = ybar' #ZeroTangent()
return sbar,fbar
end
return y, pullback
end
とすべきです。これで${\rm Re}(\bar{z})$の微分が計算できるようになりました。
足し算のrrule
足し算はシンプルです。
function ChainRulesCore.rrule(::typeof(+),a::T1,b::T1) where T1 <: ComplexField
y = a + b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar#*b
fbbar = ybar
return sbar,fabar,fbbar
end
return y, pullback
end
です。
複雑な計算
これで、
function main()
a = ComplexField(2+3.0im)
b = ComplexField(4.0+im)
c = ComplexField(1+2.0im)
println(a*b)
println(a*b')
f2(x) = real(x*x'+x+x'+x*x*x')
gnu = numerical_derivative(f2,a)
println("Numerical grad: ", gnu)
g = gradient(f2,a)[1]
println("Autograd: ", g)
end
main()
のようなケースも計算できるようになりました。
全体コード
全体コードを最後に置いておきます。
struct ComplexField{T}
z::T
end
ComplexField(a::T) where T = ComplexField{T}(a)
function Base.adjoint(a::ComplexField{T}) where T
return ComplexField{T}(a.z')
end
function Base.:*(a::ComplexField,b::ComplexField)
return ComplexField(a.z*b.z)
end
function Base.:*(a::T,b::ComplexField) where T<:Number
return ComplexField(a*b.z)
end
function Base.:*(b::ComplexField,a::T) where T<:Number
return ComplexField(a*b.z)
end
function Base.real(a::ComplexField)
ar = (a.z+a.z')/2
return real(ar)
end
function Base.:+(a::ComplexField,b::ComplexField)
return ComplexField(a.z+b.z)
end
function Base.display(a::ComplexField)
display(a.z)
end
function numerical_derivative(f,x::ComplexField)
delta = 1e-6
xd = ComplexField(x.z + delta)
fx = f(x)
fxd = f(xd)
fg_n = (fxd-fx)/delta
xd_im = ComplexField(x.z + im*delta)
fxd_im = f(xd_im)
fg_n_im = (fxd_im-fx)/delta
return (fg_n - im*fg_n_im)/2
end
using Zygote
using ChainRulesCore
function ChainRulesCore.rrule(::typeof(real),a::T1) where T1 <: ComplexField
y = real(a)
function pullback(ybar)
sbar = NoTangent()
fbar = ComplexField(ybar/2+0im)
return sbar,fbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(*),a::T1,b::T1) where T1 <: ComplexField
y = a * b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar*b
fbbar = a*ybar
return sbar,fabar,fbbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(Base.adjoint),a::T1) where T1 <: ComplexField
y =a'
function pullback(ybar)
sbar = NoTangent()
fbar = ybar'
return sbar,fbar
end
return y, pullback
end
function ChainRulesCore.rrule(::typeof(+),a::T1,b::T1) where T1 <: ComplexField
y = a + b
function pullback(ybar)
sbar = NoTangent()
fabar = ybar#*b
fbbar = ybar
return sbar,fabar,fbbar
end
return y, pullback
end
using Zygote
function main()
a = ComplexField(2+3.0im)
b = ComplexField(4.0+im)
c = ComplexField(1+2.0im)
println(a*b)
println(a*b')
f2(x) = real(x*x'+x+x'+x*x*x')
gnu = numerical_derivative(f2,a)
println("Numerical grad: ", gnu)
g = gradient(f2,a)[1]
println("Autograd: ", g)
end
main()