8
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

JuliaLangAdvent Calendar 2022

Day 3

Juliaでの自動微分を使って、ウィルティンガー微分してみる

Last updated at Posted at 2022-12-02

Juliaでの自動微分を使って行列を引数にした複雑な関数を行列で微分してみる
の続編です。前の記事では行列で微分してみましたが、今度は複素数の微分であるウィルティンガー微分をしてみます。

注:完全版がこちらにあります。

ウィルティンガー微分とは

定義

複素数$z$を二つの実数$x,y$で表すと$z = x+iy$と書けますが、この複素数$z$の適当な領域で積分可能な関数$f(z) = u + iv$を考えます(u,vは実数)。この時、偏微分

\frac{\partial f}{\partial x} = \frac{\partial u}{\partial x} + i \frac{\partial v}{\partial x} \\
\frac{\partial f}{\partial y} = \frac{\partial u}{\partial y} + i \frac{\partial v}{\partial y}

を考えることができます。関数$f$の全微分は

df = \frac{\partial f}{\partial x}  dx + \frac{\partial f}{\partial y} dy

と書くとします。ここで、$z = x + iy$, $\bar{z} = x - iy$を導入すると、

dz = dx + i dy \\
d\bar{z} = dx - i dy

から、

dx = (dz + d\bar{z})/2 \\
dy = (dz - d\bar{z})/2i

が得られます。これを用いて全微分を書き直すと、

df = \frac{\partial f}{\partial x} (dz + d\bar{z})/2 + \frac{\partial f}{\partial y} (dz - d\bar{z})/2i \\
= \frac{\partial f}{\partial z}  dz + \frac{\partial f}{\partial \bar{z}} d\bar{z}

となります。ここで、

\frac{\partial f}{\partial z} = \frac{1}{2} \left(\frac{\partial f}{\partial x}  - i \frac{\partial f}{\partial y}  \right) \\
\frac{\partial f}{\partial \bar{z}} = \frac{1}{2} \left(\frac{\partial f}{\partial x}  + i \frac{\partial f}{\partial y}  \right) 

がウィルティンガー微分です。

性質

ウィルティンガー微分は複素数関数の微分の一種ですが、この微分は色々と便利な性質があります。例えば、線形性:

\frac{\partial}{\partial z} (\alpha f + \beta g) = \alpha \frac{\partial f}{\partial z} +  \beta \frac{\partial g}{\partial z} \\
\frac{\partial}{\partial \bar{z}} (\alpha f + \beta g) = \alpha \frac{\partial f}{\partial \bar{z}} +  \beta \frac{\partial g}{\partial \bar{z}} 

積の微分:

\frac{\partial}{\partial z} (f g) = \frac{\partial f}{\partial z} g + f \frac{\partial g}{\partial z} \\
\frac{\partial}{\partial \bar{z}} (f g) = \frac{\partial f}{\partial \bar{z}} g + f \frac{\partial g}{\partial \bar{z}}

連鎖律:

\frac{\partial}{\partial z} f(g(z)) = \frac{\partial f}{\partial g} \frac{\partial g}{\partial z} + \frac{\partial f}{\partial \bar{g}} \frac{\partial \bar{g}}{\partial z} \\
\frac{\partial}{\partial \bar{z}} f(g(z)) = \frac{\partial f}{\partial g} \frac{\partial g}{\partial \bar{z}} + \frac{\partial f}{\partial \bar{g}} \frac{\partial \bar{g}}{\partial \bar{z}}

が成り立ちます。これらが成り立つということは、複素数関数の微分を通常の微分のように扱える、ということになります。

例えば、

f(z) = z^4 + 2 z \bar{z} + z

という関数があるとします。この時、ウィルティンガー微分は、$z$と$\bar{z}$を独立変数と思って微分することで、

\frac{\partial f}{\partial z} = 4 z^3 + 2 \bar{z} + 1 \\
\frac{\partial f}{\partial \bar{z}} = 2 z  

と計算できます。
このように通常の微分のように扱えるということは、自動微分もきっとできるはずです。

自動微分

微分したい関数

さて、ウィルティンガー微分をJuliaの自動微分で実装してみましょう。微分したい関数を

y(z) = {\rm re} (f(z))

とします。出力を実数とすると、Zygote.jlにあるgradientで計算することができます。ですので、Zygoteのgradientで微分を計算することを目的とします。

型の定義

まず初めに、通常の微分と混ざって混乱することを防ぐために、独自型を定義しておきます。

abstract type CField end

struct ComplexField <: CField
    value::ComplexF64
end

struct Adjoint_ComplexField{T} <: CField
    parent::T
    Adjoint_ComplexField(a) = new{typeof(a)}(a)
end

function Base.adjoint(a::ComplexField)
    return  Adjoint_ComplexField(a)
end

function Base.adjoint(a::Adjoint_ComplexField)
    return  a.parent
end

function Base.display(a::ComplexField)
    display(a.value)
end

function Base.display(a::Adjoint_ComplexField)
    display(a.parent.value')
end

ここで、$z$と$\bar{z}$に相当する二つの型を導入しました。また、adjointというのは、Juliaではa''を実行した時に呼び出される関数でして、互いに複素共役の関係ですから移り変われるようにしておきます。

次に、演算を定義します。とりあえず掛け算を定義しておくと、

function Base.:*(a::ComplexField,b::ComplexField) 
    return ComplexField(a.value*b.value)
end

function Base.:*(a::Adjoint_ComplexField,b::Adjoint_ComplexField) 
    return ComplexField(a.parent.value'*b.parent.value')
end

function Base.:*(a::Adjoint_ComplexField,b::ComplexField) 
    return ComplexField(a.parent.value'*b.value)
end

function Base.:*(a::ComplexField,b::Adjoint_ComplexField) 
    return ComplexField(a.value*b.parent.value')
end

function Base.:*(a::ComplexField,b::T) where T <: Number
    return ComplexField(a.value*b)
end

function Base.:*(a::T,b::ComplexField) where T <: Number
    return ComplexField(a*b.value)
end

function Base.:*(a::Adjoint_ComplexField,b::T) where T <: Number
    return ComplexField(a.parent.value'*b)
end

function Base.:*(a::T,b::Adjoint_ComplexField) where T <: Number
    return ComplexField(a*b.parent.value')
end

となります。これらをテストするのであれば、

function test()
    a = ComplexField(2+3im)
    display(a)
    b = a'
    display(b)
    c = 2*a*b
    display(c)
end
test()

などとすると良いでしょう。

数値微分

これから作る自動微分の結果を確かめるために、数値微分の関数を作っておきます。ウィルティンガー微分の定義に従って、

function numerical_derivative(f,x::ComplexField)
    delta = 1e-6
    xd = ComplexField(x.value + delta)
    fx = f(x)
    fxd = f(xd)

    fg_n = (fxd-fx)/delta

    xd_im = ComplexField(x.value + im*delta)
    fxd_im = f(xd_im)
    fg_n_im = (fxd_im-fx)/delta
    return (fg_n - im*fg_n_im)/2
end

としました。これは、実数部分を微小に動かして$\partial/\partial x$を作り、虚数部分を動かして$\partial /\partial y$を作り、定義通りにウィルティンガー微分を計算したものです。

実数での微分との違い

さて、ウィルティンガー微分は実数での微分と同じような形ですが、違う部分もあります。違いは、複素共役の数$\bar{z}$を独立に扱う必要がある、ということです。これは、実質2変数関数の微分のようになっていることを意味しています。例えば、実数部分をとる場合は、

{\rm Re} (z) = \frac{1}{2}(z + \bar{z})

となりますから、この関数は$z$で微分しても$\bar{z}$で微分しても有限の値が出ます。そして、

\frac{\partial }{\partial z}{\rm Re} (z) =\frac{1}{2} \\
\frac{\partial }{\partial \bar{z}}{\rm Re} (z) =\frac{1}{2} 

となりますから、通常の実数$x$を微分

\frac{\partial }{\partial x}{\rm Re} (x) =  1

とは値が異なっています。これは、$z$と$\bar{z}$を独立に扱っているためです。そのため、関数${\rm Re}(z)$を

function real_c(a::ComplexField)
    return real(a.value)
end

function real_c(a::Adjoint_ComplexField)
    return real(a.parent.value)
end

と定義して、Zygoteで微分

using Zygote
function test()
    a = ComplexField(2+3im)
    display(a)
    b = a'
    display(b)
    c = 2*a*b
    display(c)
    f(x) = real_c(x)
    g = gradient(f,a)
    println("Autograd: ", g)

    gnu = numerical_derivative(f,a)
    println("Numerical grad: ", gnu)
end

すると、

2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: ((value = 1.0,),)
Numerical grad: 0.500000000069889 - 0.0im

となり、数値微分は正しいですが、自動微分はウィルティンガー微分の値1/2になってくれません。それではどうすれば良いでしょうか。

自動微分のルールの実装

実部を取る関数を微分したときに1/2になるようにするためには、この関数の微分を自分で定義してしまえばよいでしょう。つまり、ChainRulesCoreを用いて、

using ChainRulesCore
function ChainRulesCore.rrule(::typeof(real_c),a::T1) where {T1 <: CField} 
    y = real_c(a)
    function pullback(ybar)
        sbar = NoTangent()
        fbar = ybar/2
        return sbar,fbar
    end
    return y, pullback
end

と実装すると、real_cという関数を微分したときには自動的にこの関数が呼ばれて1/2がかえるようになります。

2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: (0.5,)
Numerical grad: 0.500000000069889 - 0.0im

しかし、これでは十分ではありません。次は、x*x'の自動微分をしてみることにします。

function test()
    a = ComplexField(2+3im)
    display(a)
    b = a'
    display(b)
    c = 2*a*b
    display(c)
    f(x) = real_c(x*x')
    g = gradient(f,a)
    println("Autograd: ", g)

    gnu = numerical_derivative(f,a)
    println("Numerical grad: ", gnu)
end

まず、これはエラーが出ます。

ERROR: LoadError: Need an adjoint for constructor ComplexField. Gradient is of type Float64

これは、ComplexFieldの微分が定義されていない、というエラーです。ですので、

function ChainRulesCore.rrule(::typeof(ComplexField),a::Number) 
    y = ComplexField(a)
    function pullback(ybar)
        sbar = NoTangent()
        fbar = ybar
        return sbar,fbar
    end
    return y, pullback
end

と、微分の定義を追加します。その結果、

2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: ((value = 2.0 + 3.0im,),)
Numerical grad: 2.0000005003240062 - 3.000000500463784im

となります。間違ってますね!!!

間違いの修正

上では、コードではちゃんと自動微分ができているように見えますが、答えは間違っていました。何が起きたのでしょうか?その理由は、ウィルティンガー微分における複素共役の取り扱い方法にあります。ウィルティンガー微分では、複素数$z$の複素共役$\bar{z}$は、「独立変数」です。ですので、独立に扱って微分をしなければなりません。逆に考えれば、計算の途中で、$z$から複素共役を取ったりその逆をしたりを勝手にしてはいけません。最後まで独立変数だと思って取り扱う必要があります。
独立変数にする一番単純な方法は、関数を2変数にすることです。つまり、独立2変数の関数の自動微分を実装すれば、ちゃんとできるようになるはずです。

ということで、微分を計算したい関数を

function calc_f(x)
    xdag = x'
    a = x*xdag
    adag = xdag*x
    return real_c(a,adag)
end

とします。実部をとる関数が2変数関数になっています。

2変数関数の定義

実部をとる関数real_cを2変数関数:

function real_c(a::ComplexField,b::ComplexField)
    return real(a.value + b.value)/2
end

とその微分:

function ChainRulesCore.rrule(::typeof(real_c),a::T1,b::T2) where {T1 <: CField,T2 <: CField} 
    y = real_c(a,b)
    function pullback(ybar)
        sbar = NoTangent()
        fbar = ybar/2
        return sbar,fbar,fbar
    end
    return y, pullback
end

を定義します。これで、

using Zygote
function test()
    a = ComplexField(2+3im)
    display(a)
    b = a'
    display(b)
    c = 2*a*b
    display(c)
    f(x) = calc_f(x)
    g = gradient(f,a)
    println("Autograd: ", g)

    gnu = numerical_derivative(f,a)
    println("Numerical grad: ", gnu)
end

test()

を実行すると、

2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: ((value = 4.0 + 6.0im,),)
Numerical grad: 2.0000005003240062 - 3.000000500463784im

となります。間違ってますね!!!???

これを回避するためには、今回計算する自動微分に関わる演算をできるだけ自前で定義します。例えば、

function ChainRulesCore.rrule(::typeof(*),a::CField,b::CField) 
    y = a * b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = ybar*b
        fbbar = a*ybar
        return sbar,fabar,fbbar
    end
    return y, pullback
end

と掛け算を定義します。これで、掛け算に関する場合にこちらの関数が呼ばれるようになります。また、これでは

ERROR: LoadError: MethodError: no method matching +(::ComplexField, ::ComplexField)

と怒られますので、足し算も定義しておきます。


function Base.:+(a::ComplexField,b::ComplexField) 
    return ComplexField(a.value+b.value)
end

function Base.:+(a::Adjoint_ComplexField,b::Adjoint_ComplexField) 
    return ComplexField(a.parent.value'+b.parent.value')
end

function Base.:+(a::Adjoint_ComplexField,b::ComplexField) 
    return ComplexField(a.parent.value'+b.value)
end

function Base.:+(a::ComplexField,b::Adjoint_ComplexField) 
    return ComplexField(a.value+b.parent.value')
end

function Base.:+(a::Adjoint_ComplexField,b::Number) 
    return ComplexField(a.parent.value'+b)
end

function Base.:+(a::ComplexField,b::Number) 
    return ComplexField(a.value+b)
end

function Base.:+(a::Number,b::Adjoint_ComplexField) 
    return b + a
end

function Base.:+(a::Number,b::ComplexField) 
    return b+ a
end

さらに実行すると、

ERROR: LoadError: Need an adjoint for constructor Adjoint_ComplexField{ComplexField}. Gradient is of type ComplexField

と怒られます。これは、Adjoint_ComplexFieldの微分が定義されていないというエラーですので、

function ChainRulesCore.rrule(::typeof(Adjoint_ComplexField),a::ComplexField) 
    y = a'
    function pullback(ybar)
        sbar = NoTangent()
        fbar = ZeroTangent()
        return sbar,fbar
    end
    return y, pullback
end

と複素共役の微分を定義しておきます。複素共役は独立変数ですから、この微分はゼロになります($\partial \bar{z}/\partial z = 0$)。

これでやっと、

2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: (ComplexField(2.0 - 3.0im),)
Numerical grad: 2.0000005003240062 - 3.000000500463784im

と正しい微分になりました。

ついでなのでもう少し複雑な微分をやってみます。

function calc_f2(x)
    xdag = x'
    a = x*xdag*x + x
    adag = xdag*x*xdag + xdag
    return real_c(a,adag)
end

これは、

f(z) = {\rm Re} (z z' z + z)

です。足し算が出てきましたので、足し算の自動微分を定義しておきます。

function ChainRulesCore.rrule(::typeof(+),a::CField,b::CField) 
    y = a + b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = ybar
        fbbar = ybar
        return sbar,fabar,fbbar
    end
    return y, pullback
end

その結果、

using Zygote
function test()
    a = ComplexField(2+3im)
    display(a)
    b = a'
    display(b)
    c = 2*a*b
    display(c)
    #f(x) = calc_f(x)
    f(x) = calc_f2(x)
    g = gradient(f,a)
    println("Autograd: ", g)

    gnu = numerical_derivative(f,a)
    println("Numerical grad: ", gnu)
end

test()

を実行すると、

2.0 + 3.0im
2.0 - 3.0im
26.0 + 0.0im
Autograd: (ComplexField(11.0 - 6.0im),)
Numerical grad: 11.000003002692438 - 6.000001000927568im

となり、ちゃんと自動微分できています。

全体コード

全体のコードは以下の通りです。

abstract type CField end

struct ComplexField <: CField
    value::ComplexF64
end

struct Adjoint_ComplexField{T} <: CField
    parent::T
    Adjoint_ComplexField(a) = new{typeof(a)}(a)
end

function Base.adjoint(a::ComplexField)
    return  Adjoint_ComplexField(a)
end

function Base.adjoint(a::Adjoint_ComplexField)
    return  a.parent
end

function Base.display(a::ComplexField)
    display(a.value)
end

function Base.display(a::Adjoint_ComplexField)
    display(a.parent.value')
end

function Base.:*(a::ComplexField,b::ComplexField) 
    return ComplexField(a.value*b.value)
end

function Base.:*(a::Adjoint_ComplexField,b::Adjoint_ComplexField) 
    return ComplexField(a.parent.value'*b.parent.value')
end

function Base.:*(a::Adjoint_ComplexField,b::ComplexField) 
    return ComplexField(a.parent.value'*b.value)
end

function Base.:*(a::ComplexField,b::Adjoint_ComplexField) 
    return ComplexField(a.value*b.parent.value')
end

function Base.:*(a::ComplexField,b::T) where T <: Number
    return ComplexField(a.value*b)
end

function Base.:*(a::T,b::ComplexField) where T <: Number
    return ComplexField(a*b.value)
end

function Base.:*(a::Adjoint_ComplexField,b::T) where T <: Number
    return ComplexField(a.parent.value'*b)
end

function Base.:*(a::T,b::Adjoint_ComplexField) where T <: Number
    return ComplexField(a*b.parent.value')
end

function Base.:+(a::ComplexField,b::ComplexField) 
    return ComplexField(a.value+b.value)
end

function Base.:+(a::Adjoint_ComplexField,b::Adjoint_ComplexField) 
    return ComplexField(a.parent.value'+b.parent.value')
end

function Base.:+(a::Adjoint_ComplexField,b::ComplexField) 
    return ComplexField(a.parent.value'+b.value)
end

function Base.:+(a::ComplexField,b::Adjoint_ComplexField) 
    return ComplexField(a.value+b.parent.value')
end

function Base.:+(a::Adjoint_ComplexField,b::Number) 
    return ComplexField(a.parent.value'+b)
end

function Base.:+(a::ComplexField,b::Number) 
    return ComplexField(a.value+b)
end

function Base.:+(a::Number,b::Adjoint_ComplexField) 
    return b + a
end

function Base.:+(a::Number,b::ComplexField) 
    return b+ a
end


function numerical_derivative(f,x::ComplexField)
    delta = 1e-6
    xd = ComplexField(x.value + delta)
    fx = f(x)
    fxd = f(xd)

    fg_n = (fxd-fx)/delta

    xd_im = ComplexField(x.value + im*delta)
    fxd_im = f(xd_im)
    fg_n_im = (fxd_im-fx)/delta
    return (fg_n - im*fg_n_im)/2
end

function real_c(a::ComplexField)
    return real(a.value)
end

function real_c(a::Adjoint_ComplexField)
    return real(a.parent.value)
end

#追加
using ChainRulesCore
function ChainRulesCore.rrule(::typeof(real_c),a::T1) where {T1 <: CField} 
    y = real_c(a)
    function pullback(ybar)
        sbar = NoTangent()
        fbar = ybar/2
        return sbar,fbar
    end
    return y, pullback
end

function ChainRulesCore.rrule(::typeof(Adjoint_ComplexField),a::ComplexField) 
    y = a'
    function pullback(ybar)
        sbar = NoTangent()
        fbar = ZeroTangent()
        return sbar,fbar
    end
    return y, pullback
end

function ChainRulesCore.rrule(::typeof(ComplexField),a::Number) 
    y = ComplexField(a)
    function pullback(ybar)
        sbar = NoTangent()
        fbar = ybar
        return sbar,fbar
    end
    return y, pullback
end

function real_c(a::ComplexField,b::ComplexField)
    return real(a.value + b.value)/2
end

function ChainRulesCore.rrule(::typeof(real_c),a::T1,b::T2) where {T1 <: CField,T2 <: CField} 
    y = real_c(a,b)
    function pullback(ybar)
        sbar = NoTangent()
        fbar = ybar/2
        return sbar,fbar,fbar
    end
    return y, pullback
end


function ChainRulesCore.rrule(::typeof(*),a::CField,b::CField) 
    y = a * b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = ybar*b
        fbbar = a*ybar
        return sbar,fabar,fbbar
    end
    return y, pullback
end

function ChainRulesCore.rrule(::typeof(+),a::CField,b::CField) 
    y = a + b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = ybar
        fbbar = ybar
        return sbar,fabar,fbbar
    end
    return y, pullback
end


function calc_f(x)
    xdag = x'
    a = x*xdag
    adag = xdag*x
    return real_c(a,adag)
end

function calc_f2(x)
    xdag = x'
    a = x*xdag*x + x
    adag = xdag*x*xdag + xdag
    return real_c(a,adag)
end


using Zygote
function test()
    a = ComplexField(2+3im)
    display(a)
    b = a'
    display(b)
    c = 2*a*b
    display(c)
    #f(x) = calc_f(x)
    f(x) = calc_f2(x)
    g = gradient(f,a)
    println("Autograd: ", g)

    gnu = numerical_derivative(f,a)
    println("Numerical grad: ", gnu)
end

test()

8
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?