6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

JuliaAdvent Calendar 2024

Day 20

新世代自動微分パッケージEnzymeをJuliaで試す

Last updated at Posted at 2024-12-19

Juliaは、LLVMという中間言語に変換した後にアーキテクチャに最適なコードを生成することで高速な動作を実現しています。最近、IntelのFortranやCもLLVMを使う形になっています。つまり、LLVMは重要な中間言語です。そこで、LLVMレベルで自動微分を実装することができれば、言語に依らずに関数の自動微分ができるわけです。

この考えを使ったフレームワークはEnzyme Automatic Differentiation Frameworkです。このフレームワークはC, C++, Swift, Julia, Rust, Fortran, TensorFlow等の言語に対応しています。論文はこちらです。

例えば、C言語であれば、

#include <stdio.h>

double square(double x) {
    return x * x;
}

double __enzyme_autodiff(void*, double);
int main() {
    double x = 3.14;
    // Evaluates to 2 * x = 6.28
    double grad_x = __enzyme_autodiff((void*)square, x);
    printf("square'(%f) = %f\n", x,  grad_x);
    return 0;
}

で関数の自動微分ができます。LLVMのレベルで自動微分を実装しているために、非常に高速に自動微分ができるとのことです。

Enzyme.jl

さて、Enzymeの開発者の方のセミナーを聞く機会があり、EnzymeのJulia版であるEnzyme.jlが非常に有用であることを知りました。そこで、Enzyme.jlを試してみます。

他の自動微分パッケージとの違い

Juliaにはさまざまな自動微分パッケージがあります。これまで私が書いてきた記事では、Flux.jlを扱ってきました。Flux.jlはZygote.jlを使っており、ChainRule.jlやChainRuleCore.jlを使っています。
Enzyme.jlと他の自動微分パッケージの最大の違いは、

  1. 配列を途中変更していても自動微分できる
  2. inplace計算ができる

です。1.に関しては、ZygoteやPyTorchなどを使ったことがある人は心当たりがあるかもしれませんが、これらのパッケージでは、計算の途中で変数を上書きするような計算をすると、自動微分ができなくなってしまっていました。Enzyme.jlの場合、できます。

2.に関しては、他のパッケージではy = f(x)のように出力yがあって、そのyを微分するという形でした。ですので、f!(y,x)のような、実行するとyが書き変わるような関数に対する微分を計算することはできませんでした。Enzyme.jlの場合、できます。

自動微分の例

まず、微分したい関数を長さ2のベクトル$\vec{x}$から長さ1のベクトル$\vec{y}$を作る関数:

 [\vec{y}]_1 = [\vec{x}]_1^2 + [\vec{x}]_2 [\vec{x}]_1 
using Enzyme

function f!(x::Array{Float64}, y::Array{Float64})
    y[1] = x[1] * x[1] + x[2] * x[1]
    return nothing
end

とします。この関数は出力はnothingで、yの値が書き変わる関数です。ですので、これまでのZygoteやFlux、そしてPyTorchなどでは自動微分できません。

微分を定義するために、以下のようなadjointというものを定義します。

\bar{a} \equiv \frac{\partial l}{\partial a}

これは、誤差逆伝播法を使うときに使います。例えば、yのx微分は、

\bar{x} \equiv \frac{\partial l}{\partial x} = \frac{\partial l}{\partial y}\frac{\partial y}{\partial a} = \bar{y} \frac{\partial y}{\partial x}

と書けます。ここで、yのx微分を計算したい場合には、$\bar{y}= 1$として誤差逆伝播することで、$\bar{x}$が$\frac{\partial y}{\partial x}$となります。
Enzymeでは、誤差逆伝播を使う微分の計算をReverseモードと呼びます。そして、微分は

x  = [2.0, 2.0]
bx = [0.0, 0.0]
y  = [0.0]
by = [1.0]
Enzyme.autodiff(Reverse, f!, Duplicated(x, bx), Duplicated(y, by))

で計算できます。微分は$\bar{x}$に格納されていますので、bxが$\frac{\partial y}{\partial \vec{x}}$です。つまり、yをxの各要素で微分したものです。値は

julia> bx
2-element Vector{Float64}:
 6.0
 2.0

となります。$x = (2,2)$の時の微分値です。微分値は

\frac{\partial [\vec{y}]_1}{\partial [\vec{x}]_1} = 2 [\vec{x}]_1 + [\vec{x}]_2
\frac{\partial [\vec{y}]_1}{\partial [\vec{x}]_2} =  [\vec{x}]_1 

ですから、これに値を代入すると6と2になることがわかりますね。

もちろん、通常の返り値を返す関数

f(x) = x[1]^2+x[1]*x[2]

に対する微分も計算できて、これはZygoteと似た感じに、

gradient(Reverse,f,[2.0,2.0])

で計算できます。

具体例

さて、より複雑な自動微分を行うために、自分で定義した型に関する微分をしてみましょう。型はJuliaでの自動微分を使って行列で微分してみるで使っている

struct Field{Nc,L}
    A::Array{Float64,3}
end

という型に対して自動微分をやってみましょう。このFieldという型は内部に3次元配列を持っています。この三次元配列Aは、1次元上の格子点の上に並んだ行列を表現しているとみなします。つまり、A = zeros(Float64,3,3,10)であれば、10点の格子点の上に$3 \times 3$の零行列が置かれているとみなします。
この型に関する演算を前の記事と同じように

struct Field{Nc,L}
    A::Array{Float64,3}
end

function Field(Nc,L)
    A = zeros(Float64,Nc,Nc,L)
    return Field{Nc,L}(A)
end

function identity(Nc,L)
    A = zeros(Float64,Nc,Nc,L)
    for i=1:L
        for ic=1:Nc
            A[ic,ic,i] = 1
        end
    end
    return Field{Nc,L}(A)
end

function random_field(Nc,L)
    A = rand(Float64,Nc,Nc,L)
    return Field{Nc,L}(A)
end


function Base.copy(a::Field{Nc,L}) where {Nc,L}
    A = zeros(Float64,Nc,Nc,L)
    for i=1:L
        for ic=1:Nc
            for jc=1:Nc
                A[jc,ic,i] = a.A[jc,ic,i]
            end
        end
    end
    return Field{Nc,L}(A)
end

function Base.zero(a::Field{Nc,L}) where {Nc,L}
    return Field(Nc,L)
end


function Base.:+(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        c.A[:,:,i] = view(a.A,:,:,i) .+ view(b.A,:,:,i)
    end
    return c
end

function Base.:-(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        c.A[:,:,i] = view(a.A,:,:,i) .- view(b.A,:,:,i)
    end
    return c
end


function Base.:*(a::Field{Nc,L},b::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        mul!(view(c.A,:,:,i),view(a.A,:,:,i),view(b.A,:,:,i))
    end
    return c
end

function Base.:*(a::T,b::Field{Nc,L}) where {Nc,L,T <: Number}
    c = zero(b)
    for i=1:L
        c.A[:,:,i] = a*view(b.A,:,:,i)
    end
    return c
end

function Base.:*(b::Field{Nc,L},a::T) where {Nc,L,T <: Number}
    c = zero(b)
    for i=1:L
        c.A[:,:,i] = a*view(b.A,:,:,i)
    end
    return c
end

function Base.display(a::Field{Nc,L}) where {Nc,L}
    for i=1:L
        println("i = $i")
        display(a.A[:,:,i])
        println("\t")
    end
end

function Base.adjoint(a::Field{Nc,L}) where {Nc,L}
    c = zero(a)
    for i=1:L
        for ic=1:Nc
            for jc=1:Nc
                c.A[ic,jc,i] = a.A[jc,ic,i]
            end
        end
    end
    return c
end

using LinearAlgebra
function LinearAlgebra.tr(a::Field{Nc,L}) where {Nc,L}
    c = 0.0
    for i=1:L
        for ic=1:Nc
            c += a.A[ic,ic,i] 
        end
    end
    return c
end

と定義します。行列の掛け算や、トレースの定義をしています。
さて、ここで、スカラー値関数として、

f(A) = \sum_{i=1}^N {\rm Tr} A_i^2 

とします。そして、この微分として、

\frac{\partial f}{\partial [A_i]_{jk}}

を計算します。
これは

function test()
    Nc = 3
    L = 4
    U0 = identity(Nc, L)
    a = random_field(Nc, L)
    b = random_field(Nc, L)
    c = a * b
    display(c)
    d = a + b
    display(d)
    println(tr(c))

    dfda = zero(a)

    ff(a) = tr(a * a)
    fa = ff(a)#calc_f(a, U0)
    println(fa)
    for i = 1:L
        delta = 1e-9
        for ic = 1:Nc
            for jc = 1:Nc
                a_p = copy(a)
                a_p.A[ic, jc, i] += delta
                fad = ff(a_p)#calc_f(a_p, U0)
                dfda.A[ic, jc, i] = (fad - fa) / delta
            end
        end
    end
    display(dfda)
    ba = zero(a)

    Enzyme.autodiff(Reverse, ff, Duplicated(a, ba))
    display(ba)


    display(ba - dfda)

    return
end
test()

で計算できます。
出力結果は、

i = 1
3×3 Matrix{Float64}:
 0.752197  0.555021  0.846438
 1.50753   1.31904   1.47673
 0.868858  0.994287  0.746038

i = 2
3×3 Matrix{Float64}:
 0.57742   1.00897   1.39861
 0.522834  0.796907  0.892183
 0.65779   0.949414  1.08154

i = 3
3×3 Matrix{Float64}:
 0.274848  0.387103  0.815569
 0.98411   0.745404  1.53867
 1.14727   0.871401  1.76196

i = 4
3×3 Matrix{Float64}:
 0.221035  0.281214  0.345309
 0.252666  0.253364  0.274131
 0.486649  0.694152  0.863662

i = 1
3×3 Matrix{Float64}:
 0.728151  1.63787   0.458188
 1.39312   1.39827   1.51125
 1.58845   0.899655  0.819125

i = 2
3×3 Matrix{Float64}:
 0.186181  1.62756   1.13569
 0.240869  0.685574  1.34216
 1.13328   1.4496    1.55671

i = 3
3×3 Matrix{Float64}:
 0.85873   0.528429  1.77271
 1.49914   1.11491   1.5745
 0.610591  1.12549   1.4196

i = 4
3×3 Matrix{Float64}:
 1.00726   0.985933  1.0512
 0.592349  0.563672  0.863005
 0.39699   0.874019  1.46102

9.393422983894713
9.125574636278706
i = 1
3×3 Matrix{Float64}:
 0.380103  1.42496  1.5076
 1.34081   1.97664  0.968907
 0.463963  1.08178  0.319892

i = 2
3×3 Matrix{Float64}:
 0.365031  0.452109  0.394092
 1.98133   0.747372  0.954541
 1.20072   1.1033    1.38854

i = 3
3×3 Matrix{Float64}:
 0.361712  1.26512  1.1646
 0.295017  1.2362   1.68356
 1.72659   1.39135  1.63114

i = 4
3×3 Matrix{Float64}:
 0.177888  0.447013  0.158401
 0.320508  0.145645  0.725429
 0.50504   0.129376  1.76301

i = 1
3×3 Matrix{Float64}:
 0.380103  1.42496  1.5076
 1.34081   1.97664  0.968908
 0.463964  1.08178  0.319892

i = 2
3×3 Matrix{Float64}:
 0.36503  0.45211   0.394092
 1.98133  0.747371  0.954541
 1.20072  1.1033    1.38854

i = 3
3×3 Matrix{Float64}:
 0.361713  1.26512  1.1646
 0.295017  1.2362   1.68356
 1.72659   1.39136  1.63114

i = 4
3×3 Matrix{Float64}:
 0.177888  0.447011  0.158401
 0.320507  0.145645  0.725427
 0.505039  0.129373  1.76301

i = 1
3×3 Matrix{Float64}:
 -8.50562e-8  4.22607e-7  2.08171e-7
  3.12499e-7  1.10879e-6  9.03894e-7
  1.06085e-6  8.70911e-7  1.29601e-7

i = 2
3×3 Matrix{Float64}:
 -1.82106e-7   2.1892e-7   -3.17422e-7
  1.30051e-7  -5.32971e-7  -1.81786e-7
  1.13464e-7   3.61746e-8   7.69233e-7

i = 3
3×3 Matrix{Float64}:
 3.44475e-7  1.54834e-7   3.75643e-7
 1.02411e-7  3.16338e-7  -1.39889e-7
 1.17439e-6  7.72986e-7   7.7299e-7

i = 4
3×3 Matrix{Float64}:
 -1.94369e-7  -1.82672e-6  -4.90448e-7
 -1.01142e-6  -4.05601e-7  -1.7819e-6
 -1.35268e-7  -2.16228e-6  -7.19728e-7

となります。最後の部分は、数値微分と自動微分の比較です。ちゃんと微分できています。
非常に興味深いことに、Enzyme.jlの場合、誤差逆伝播のカスタムな関数を全く定義していません。通常の演算だけを定義しただけで、自動微分ができてしまいました。これは、どうやら、LLVMレベルで自動微分を実現しているために、LLVMでは型などがなくなってしまっているために、同じように自動微分できるそうです。

inplaceな計算の例

function calc_f!(a, b)
    c = b + b'
    a.A .= c.A
    return nothing
end

function trf(a, b)
    calc_f!(a, b)
    return tr(a * a)
end

次は、値を代入するような計算をやってみます。数値微分と自動微分は

function test()
    Nc = 3
    L = 4
    U0 = identity(Nc, L)
    a = random_field(Nc, L)
    b = random_field(Nc, L)

    dfdb = zero(a)
    # ff(a) = tr(a * a)
    ff(a, b) = trf(a, b)
    fa = ff(a, b)#calc_f(a, U0)
    println(fa)
    for i = 1:L
        delta = 1e-9
        for ic = 1:Nc
            for jc = 1:Nc
                b_p = deepcopy(b)
                b_p.A[ic, jc, i] += delta
                fad = ff(a, b_p)#calc_f(a_p, U0)
                dfdb.A[ic, jc, i] = (fad - fa) / delta
            end
        end
    end
    display(dfdb)
    ba = zero(a)
    bb = zero(a)

    Enzyme.autodiff(Reverse, ff, Duplicated(a, ba), Duplicated(b, bb))
    #Enzyme.autodiff(Reverse, ff, Duplicated(a, ba))
    display(bb)
    display(bb - dfdb)

    return
end
test()

で計算できますが、こちらも、カスタムな微分は定義せずに、自動微分が計算できてしまいます。
出力結果は、

35.020562939496834
i = 1
3×3 Matrix{Float64}:
 2.97997   3.95193   0.942734
 3.95193   0.241478  1.76242
 0.942734  1.76242   1.04739

i = 2
3×3 Matrix{Float64}:
 3.37172  4.43166  2.59798
 4.43166  3.89203  3.51522
 2.59798  3.51522  1.43856

i = 3
3×3 Matrix{Float64}:
 4.16045   0.156319  3.17209
 0.156319  6.98827   3.78152
 3.17209   3.78152   1.67992

i = 4
3×3 Matrix{Float64}:
 7.58839  3.04797  4.93254
 3.04797  4.90829  6.70244
 4.93254  6.70244  6.96344

i = 1
3×3 Matrix{Float64}:
 2.97998   3.95193   0.942736
 3.95193   0.241484  1.76242
 0.942736  1.76242   1.04739

i = 2
3×3 Matrix{Float64}:
 3.37173  4.43166  2.59798
 4.43166  3.89203  3.51522
 2.59798  3.51522  1.43856

i = 3
3×3 Matrix{Float64}:
 4.16045   0.156319  3.17209
 0.156319  6.98827   3.78153
 3.17209   3.78153   1.67992

i = 4
3×3 Matrix{Float64}:
 7.58839  3.04797  4.93254
 3.04797  4.90829  6.70244
 4.93254  6.70244  6.96344

i = 1
3×3 Matrix{Float64}:
 2.92204e-6   2.71992e-6   2.15861e-6
 2.71992e-6   5.58389e-6  -5.67338e-7
 2.15861e-6  -5.67338e-7   2.53843e-6

i = 2
3×3 Matrix{Float64}:
  5.06033e-6   1.0123e-6   -1.86649e-7
  1.0123e-6   -1.19775e-6  -3.30432e-7
 -1.86649e-7  -3.30432e-7   6.82783e-7

i = 3
3×3 Matrix{Float64}:
  1.87007e-6  -8.97732e-7  2.81788e-6
 -8.97732e-7   5.56427e-7  2.66112e-6
  2.81788e-6   2.66112e-6  7.2291e-8

i = 4
3×3 Matrix{Float64}:
 2.38505e-6  3.87564e-7  1.10566e-6
 3.87564e-7  5.77211e-7  7.78424e-7
 1.10566e-6  7.78424e-7  2.81158e-6

です。ちゃんと数値微分と自動微分があっていますね。

カスタム微分

Enzyme.jlはForwardとReverseの両方の微分が可能です。ここでは、誤差逆伝播を使うReverseの方に着目します。
カスタム微分ですが、公式のドキュメントにありますように、まずカスタムを作らずに動かして、ちゃんと動かなかったあるいは予想するものを違う結果が出た場合や、手計算ですでに微分の式が出ている時など、にカスタム微分を実装するようです。上で見てきましたように、Zygoteでは実装する必要があった独自型の微分も、Enzymeの場合LLVMレベルでの自動微分によって自動的に行うことができていますので、独自型だからといってカスタム微分を定義する必要はないようです。

ドキュメントの例

ドキュメントにある例をみていきます。
関数としては、

function f!(y, x)
    y .= x.^2
    return sum(y)
end

を考えます。入力の引数yは関数の中で値を代入されていますので、ドキュメントではfですが、ここではf!という名前にしています。
この関数を誤差逆伝播で計算する場合には、関数fがxとyにも依存していることに注意して、

\frac{\partial l}{\partial x_i} = \frac{\partial l}{\partial f} \frac{\partial f}{\partial x_i} + \sum_j \frac{\partial l}{\partial y_j} \frac{\partial y_j}{\partial x_i} = \frac{\partial l}{\partial f} \frac{\partial f}{\partial x_i} + \frac{\partial l}{\partial y_i} \frac{\partial y_i}{\partial x_i}

となります。yに関する微分$\frac{\partial l}{\partial y_i}$は、ゼロです。なぜなら、f!(y, x)のyを微小変化させても、yはxの二乗で上書きされますので、微小変化の差である微分はゼロとなってしまうからです。

まず、通常のEnzymeの自動微分を使うと、

function f!(y, x)
    y .= x .^ 2
    return sum(y)
end

using Enzyme
x = [3.0, 1.0]
dx = [0.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]

g!(y, x) = f!(y, x)^2 # function to differentiate

autodiff(Reverse, g!, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
println(dx)
println(dy)

となり、出力結果は、

[120.0, 40.0]
[0.0, 0.0]

となります。

次に、カスタム微分を定義してみましょう。
カスタム微分を定義するためには、using Enzymeのほかに、

import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

のようにいくつかのパッケージをインポートします。誤差逆伝播を使うReverseモードでは、二つの関数を定義する必要があるようです。一つ目は

function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f!)}, ::Type{<:Active},
                          y::Duplicated, x::Duplicated)
    println("In custom augmented primal rule.")
    # Compute primal
    if needs_primal(config)
        primal = func.val(y.val, x.val)
    else
        y.val .= x.val.^2 # y still needs to be mutated even if primal not needed!
        primal = nothing
    end
    # Save x in tape if x will be overwritten
    if overwritten(config)[3]
        tape = copy(x.val)
    else
        tape = nothing
    end
    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end

です。Enzymeにはいくつか見慣れない概念がありますので、そちらを解説しつつ、上の関数を解説します。
まず、Enzymeには変数や関数にアノーテション(注釈)をつけることで、微分が存在するかどうか、などを伝えます。

  • Const この型は微分を持ちません。ここでは、Const{typeof(f!)}とありますが、これは、関数f!の微分がないことを示します。ドキュメントによるとクロージャの場合は微分がある場合があるようです
  • Active この型は微分が存在する場合に使われるようです。格納される値には微分の値が入ったりするようです(まだ完全に理解しておらず。)。
  • Duplicated この型は通常Duplicated(x, dx)のような使われ方をしており、値と微分が格納されます。特に、微分の値を配列としてあらかじめ確保している時に使われるようです。
    の三つがあります。
    augmented_primalという関数は、微分の値だけではなくその関数の値そのものが必要かどうか、というところを見ています。関数の値が必要なときはprimal = func.val(y.val, x.val)としていますが、これは関数を呼び出して値を計算しています。また、関数の値が必要ない場合には、
y.val .= x.val.^2 # y still needs to be mutated even if primal not needed!
primal = nothing

としています。これは、関数を呼んでしまうと値を計算してしまいますが、それはしたくなく、一方でyは更新されるべきものですので、更新されるべきyの計算y.val .= x.val.^2 だけをやっているようです。逆に言えば、値が更新されるような引数がある時にちゃんと定義しておく、というような関数です。次に、誤差逆伝播の式です。

function reverse(config::RevConfigWidth{1}, func::Const{typeof(f)}, dret::Active, tape,
                 y::Duplicated, x::Duplicated)
    println("In custom reverse rule.")
    # retrieve x value, either from original x or from tape if x may have been overwritten.
    xval = overwritten(config)[3] ? tape : x.val
    # accumulate dret into x's shadow. don't assign!
    x.dval .+= 2 .* xval .* dret.val
    # also accumulate any derivative in y's shadow into x's shadow.
    x.dval .+= 2 .* xval .* y.dval
    make_zero!(y.dval)
    return (nothing, nothing)
end

誤差逆伝播はよくデルタルールと呼ばれていますが、式を再掲すると、

\frac{\partial l}{\partial x_i} = \frac{\partial l}{\partial f} \frac{\partial f}{\partial x_i} + \sum_j \frac{\partial l}{\partial y_j} \frac{\partial y_j}{\partial x_i} = \frac{\partial l}{\partial f} \frac{\partial f}{\partial x_i} + \frac{\partial l}{\partial y_i} \frac{\partial y_i}{\partial x_i}

となっていますが、具体的に計算すると、

\frac{\partial l}{\partial x_i} =   \frac{\partial l}{\partial f} 2 x_i + \frac{\partial l}{\partial y_i} 2 x_i

となっています。ここで、インプットは、$ \frac{\partial l}{\partial f}$と$\frac{\partial l}{\partial y_i}$の二つです。$\frac{\partial l}{\partial x_i} $は関数内ではx.dvalです。xはDuplicated型ですが、この型はx.valx.dvalが格納されています。
デルタルールは二つの和なので、

x.dval .+= 2 .* xval .* dret.val

x.dval .+= 2 .* xval .* y.dval

となっています。

make_zero!(y.dval)

は、yの微分をゼロにしているところです。この関数内ではyがxの2乗によって上書きされているために、上述したように、微分がゼロになっていル、という事情が反映されています。
また、returnは

return (nothing, nothing)

となっていますが、これは、返り値として、xの微分もyの微分も返ってこないことを意味しています。xの微分もyの微分も、Duplicated型のdvalに格納されています。

これを定義して計算すると、

In custom augmented primal rule.
In custom reverse rule.
[120.0, 40.0]
[0.0, 0.0]

となります。同じ結果になっています。

ウィルティンガー微分

次に、複素数の微分であるウィルティンガー微分について考えます。前の記事はJuliaでの自動微分を使って、ウィルティンガー微分してみる:完全版にあります。

定義

複素数$z$を二つの実数$x,y$で表すと$z = x+iy$と書けますが、この複素数$z$の適当な領域で積分可能な関数$f(z) = u + iv$を考えます(u,vは実数)。この時、偏微分

\displaylines{
\frac{\partial f}{\partial x} = \frac{\partial u}{\partial x} + i \frac{\partial v}{\partial x} \\
\frac{\partial f}{\partial y} = \frac{\partial u}{\partial y} + i \frac{\partial v}{\partial y}
}

を考えることができます。関数$f$の全微分は

df = \frac{\partial f}{\partial x}  dx + \frac{\partial f}{\partial y} dy

と書くとします。ここで、$z = x + iy$, $\bar{z} = x - iy$を導入すると、

\displaylines{
dz = dx + i dy \\
d\bar{z} = dx - i dy
}

から、

\displaylines{
dx = (dz + d\bar{z})/2 \\
dy = (dz - d\bar{z})/2i
}

が得られます。これを用いて全微分を書き直すと、

\displaylines{
df = \frac{\partial f}{\partial x} (dz + d\bar{z})/2 + \frac{\partial f}{\partial y} (dz - d\bar{z})/2i \\
= \frac{\partial f}{\partial z}  dz + \frac{\partial f}{\partial \bar{z}} d\bar{z}
}

となります。ここで、

\displaylines{
\frac{\partial f}{\partial z} = \frac{1}{2} \left(\frac{\partial f}{\partial x}  - i \frac{\partial f}{\partial y}  \right) \\
\frac{\partial f}{\partial \bar{z}} = \frac{1}{2} \left(\frac{\partial f}{\partial x}  + i \frac{\partial f}{\partial y}  \right) 
}

がウィルティンガー微分です。

コード

新しい複素数型を定義して、それに関するウィルティンガー微分の自動微分をしてみましょう。型は

struct ComplexField{T}
    z::T
end

ComplexField(a::T) where {T} = ComplexField{T}(a)

function Base.adjoint(a::ComplexField{T}) where {T}
    return ComplexField{T}(a.z')
end

function Base.:*(a::ComplexField, b::ComplexField)
    return ComplexField(a.z * b.z)
end

function Base.:*(a::T, b::ComplexField) where {T<:Number}
    return ComplexField(a * b.z)
end

function Base.:*(b::ComplexField, a::T) where {T<:Number}
    return ComplexField(a * b.z)
end


function Base.real(a::ComplexField)
    ar = (a.z + a.z') / 2
    return real(ar)
end

function Base.:+(a::ComplexField, b::ComplexField)
    return ComplexField(a.z + b.z)
end


function Base.display(a::ComplexField)
    display(a.z)
end

function numerical_derivative(f, x::ComplexField)
    delta = 1e-6
    xd = ComplexField(x.z + delta)
    fx = f(x)
    fxd = f(xd)

    fg_n = (fxd - fx) / delta

    xd_im = ComplexField(x.z + im * delta)
    fxd_im = f(xd_im)
    fg_n_im = (fxd_im - fx) / delta
    return (fg_n - im * fg_n_im) / 2
end

と定義しておきます。微分する関数は

f(z) = {\rm Re} (z \bar{z} + z + \bar{z} + z^2 \bar{z})

まず、これらを定義した状態で、微分を計算してみます。
数値微分はウィルティンガー微分の定義通りですので、数値微分と比較して正しさを調べます。ですので、

function main()
    a = ComplexField(2 + 3.0im)
    b = ComplexField(4.0 + im)
    c = ComplexField(1 + 2.0im)
    println(a * b)
    println(a * b')
    f2(x) = real(x * x' + x + x' + x * x * x')
    gnu = numerical_derivative(f2, a)
    println("Numerical grad: ", gnu)
    g = gradient(Reverse, f2, a)[1]
    println("Autograd: ", g)
end
main()

を実行すると、

ComplexField{ComplexF64}(5.0 + 14.0im)
ComplexField{ComplexF64}(11.0 + 10.0im)
Numerical grad: 13.5000035044186 - 9.000001501391353im
Autograd: ComplexField{ComplexF64}(27.0 + 18.0im)

となります。合っていませんね。これはある意味当たり前でして、Enzymeにウィルティンガー微分がどのような微分であるかという情報を全く入れていないからです。

それでは、カスタム微分を実装してみます。$z$と$\bar{z}$を独立に扱うところに注意する必要があります。二変数あるということですから、デルタルールは

\displaylines{
 \frac{\partial l}{\partial z}  = \frac{\partial l}{\partial f} \frac{\partial f}{\partial z} + \frac{\partial l}{\partial \bar{f}} \frac{\partial \bar{f}}{\partial z}\\
 \frac{\partial l}{\partial \bar{z}}  =\frac{\partial l}{\partial f} \frac{\partial f}{\partial \bar{z}}
+\frac{\partial l}{\partial \bar{f}} \frac{\partial \bar{f}}{\partial \bar{z}}
}

となります。

例えば、実数部分をとる場合は、

{\rm Re} (z) = \frac{1}{2}(z + \bar{z})

となりますから、この関数は$z$で微分しても$\bar{z}$で微分しても有限の値が出ます。そして、

\begin{align}
\frac{\partial }{\partial z}{\rm Re} (z) =\frac{1}{2} \\
\frac{\partial }{\partial \bar{z}}{\rm Re} (z) =\frac{1}{2} 
\end{align}

となりますから、通常の実数$x$を微分

\frac{\partial }{\partial x}{\rm Re} (x) =  1

とは値が異なっています。これは、$z$と$\bar{z}$を独立に扱っているためです。
実部をとる関数に適用すると、実部をとる関数は当然実数なので、

\displaylines{
 \frac{\partial l}{\partial z}  = \frac{\partial l}{\partial f} \frac{1}{2}  
}

となります。そこで、

function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(Base.real)}, ::Type{<:Active},
    a::Active{<:T1}) where {T1<:ComplexField}
    println("In custom augmented primal rule.")
    # Compute primal
    if needs_primal(config)
        primal = func.val(a.val)
    else
        primal = nothing
    end
    # Save x in tape if x will be overwritten
    if overwritten(config)[2]
        tape = copy(a.val)
    else
        tape = nothing
    end
    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end

function reverse(config::RevConfigWidth{1}, func::Const{typeof(real)}, dret::Active, tape,
    a::Active{T1}) where {T1<:ComplexField}
    println("In custom reverse rule.")
    # retrieve x value, either from original x or from tape if x may have been overwritten.
    #aval = overwritten(config)[3] ? tape : a.val
    # accumulate dret into x's shadow. don't assign!

    return Tuple{T1}((ComplexField(dret.val / 2 + 0im),))
end

とします。しかし、これではまだ

ComplexField{ComplexF64}(5.0 + 14.0im)
ComplexField{ComplexF64}(11.0 + 10.0im)
Numerical grad: 13.5000035044186 - 9.000001501391353im
In custom augmented primal rule.
In custom reverse rule.
Autograd: ComplexField{ComplexF64}(13.5 + 9.0im)

となってしまって、あっていません。そこで、掛け算に関するデルタルール


function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(*)}, ::Type{<:Active},
    a::Active{<:T1}, b::Active{<:T1}) where {T1<:ComplexField}
    println("In custom augmented primal rule in mult.")
    # Compute primal
    if needs_primal(config)
        primal = func.val(a.val, b.val)
    else
        primal = nothing
    end
    # Save x in tape if x will be overwritten
    if overwritten(config)[2]
        tape = copy(a.val)
    else
        tape = nothing
    end
    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end

function reverse(config::RevConfigWidth{1}, func::Const{typeof(*)}, dret::Active, tape,
    a::Active{T1}, b::Active{<:T1}) where {T1<:ComplexField}
    println("In custom reverse rule in mult.")

    return Tuple{T1,T1}((dret.val * b.val, a.val * dret.val))
end

を実装しました。これで、

ComplexField{ComplexF64}(5.0 + 14.0im)
ComplexField{ComplexF64}(11.0 + 10.0im)
Numerical grad: 13.5000035044186 - 9.000001501391353im
In custom augmented primal rule in mult.
In custom augmented primal rule in mult.
In custom augmented primal rule in mult.
In custom augmented primal rule.
In custom reverse rule.
In custom reverse rule in mult.
In custom reverse rule in mult.
In custom reverse rule in mult.
Autograd: ComplexField{ComplexF64}(13.5 - 9.0im)

ちゃんとウィルティンガー微分が自動微分で計算できました。

pullback

他の自動微分パッケージ、例えばZygoteでは、pullbackを実装します。pullbackは返り値を持っています。この場合、返り値でメモリアロケーションが起きてしまうことになります。Enzymeでは、Duplicatedを使うことで、メモリアロケーションなしにpullbackを計算することができます。例に従ってみていきましょう。
行列Aと行列Bの積を計算することを考えます。

R = A B

このとき、この積のpullbackは、

\begin{align}
\frac{\partial l}{\partial A_{ij}} &= \sum_{kl} \frac{\partial l}{\partial R_{kl}} \frac{\partial R_{kl}}{\partial A_{ij}}  =  \sum_{kl} \frac{\partial l}{\partial R_{kl}} \frac{\partial \sum_{n} A_{kn} B_{nl}}{\partial A_{ij}} \\
&=\sum_{kl} \frac{\partial l}{\partial R_{kl}} \sum_{n} B_{nl} \delta_{ki} \delta_{nj} = \sum_{l} \frac{\partial l}{\partial R_{il}}  B_{jl} = \left[ \frac{\partial l}{\partial R} B^{T} \right]_{ij}
\end{align}

となります。ここで、行列による微分は

\left[ \frac{\partial l}{\partial R} \right]_{ij} \equiv
\frac{\partial l}{\partial R_{ij}}

と定義しています(流儀が二つあるので注意)。この積の計算の関数mymul!

using Enzyme, Random

function mymul!(R, A, B)
    @assert axes(A, 2) == axes(B, 1)
    @inbounds @simd for i in eachindex(R)
        R[i] = 0
    end
    @inbounds for j in axes(B, 2), i in axes(A, 1)
        @inbounds @simd for k in axes(A, 2)
            R[i, j] += A[i, k] * B[k, j]
        end
    end
    nothing
end

と定義します。
pullbackを計算するには、$\frac{\partial l}{\partial R}$が必要です。つまり、$\frac{\partial l}{\partial R}$を引数としています。そこで、

Random.seed!(1234)
A = rand(5, 3)
B = rand(3, 7)

R = zeros(size(A, 1), size(B, 2))
∂z_∂R = rand(size(R)...)  # Some gradient/tangent passed to us
∂z_∂R0 = copyto!(similar(∂z_∂R), ∂z_∂R)  # exact copy for comparison

∂z_∂A = zero(A)
∂z_∂B = zero(B)

のように、∂z_∂Rという値を用意して(Enzyme.jlのFAQのコードではlではなくzを使っていますのでzを使いました)、

Enzyme.autodiff(Reverse, mymul!, Const, Duplicated(R, ∂z_∂R), Duplicated(A, ∂z_∂A), Duplicated(B, ∂z_∂B))

c = R  A * B &&
    ∂z_∂A  ∂z_∂R0 * B' &&  # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[1]
    ∂z_∂B  A' * ∂z_∂R0      # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[2]
println(c)

計算してみましょう。∂z_∂A ≈ ∂z_∂R0 * B'がpullbackの計算そのものです。つまり、Enzyme.autodiffを行うことで、pullbackが計算されたことになります。

カスタム微分の調査

カスタム微分について、書き方がよくわからない場合には、Enzyme.jlのソースコード内のカスタム微分を見ると理解しやすいかもしれません。特に、primalとは何かなどは、ここをよく見ればわかってきそうです。

6
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?