1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

新世代自動微分パッケージEnzymeをJuliaで試す その2:structの微分

Posted at

Juliaでの新しい自動微分パッケージEnzyme.jlについてです。以前、「新世代自動微分パッケージEnzymeをJuliaで試す」で紹介しました。
このパッケージの面白いところは、自分で定義した構造体を使っても微分ができることです。公式ページの「What types are differentiable?」によると、Structの中の実数は微分できます。そして、Structの中にStructを定義するような入れ子構造になっていても、フィールドが実数であれば追跡して微分してくれます。それでは、Structのフィールドに実数以外が入っている場合はどうなるのでしょう? それを確認してみます。

環境

  • Julia 1.11.5 (2025-04-14)
  • Enzyme v0.13.35

簡単な例

以下のような簡単なStructを定義します。

using Enzyme
# Minimal test type for differentiating through a user-defined struct:
# it pairs one real-valued field with one integer-valued field.
mutable struct Simplestruct
    a::Float64  # real-valued component
    b::Int64    # integer-valued component
end
"""
    +(a::Simplestruct, b::Simplestruct)

Return a new `Simplestruct` holding the field-wise sums of `a` and `b`.
"""
Base.:+(a::Simplestruct, b::Simplestruct) = Simplestruct(a.a + b.a, a.b + b.b)
"""
    -(a::Simplestruct, b::Simplestruct)

Return a new `Simplestruct` holding the field-wise differences `a - b`.
"""
Base.:-(a::Simplestruct, b::Simplestruct) = Simplestruct(a.a - b.a, a.b - b.b)
"""
    *(a::Simplestruct, b::Simplestruct)

Return a new `Simplestruct` holding the field-wise products of `a` and `b`.
"""
Base.:*(a::Simplestruct, b::Simplestruct) = Simplestruct(a.a * b.a, a.b * b.b)
"""
    *(a::Float64, b::Simplestruct)

Scale both fields of `b` by the scalar `a` and return the result as a new
`Simplestruct`.

The scaled integer field `a * b.b` is a `Float64`; the constructor converts it back
to `Int64`, which only succeeds when the product is integer-valued.
"""
Base.:*(a::Float64, b::Simplestruct) = Simplestruct(a * b.a, a * b.b)
"""
    *(b::Simplestruct, a::Float64)

Right-hand scalar multiplication: scale both fields of `b` by `a` and return a new
`Simplestruct`.

The scaled integer field `a * b.b` is a `Float64`; the constructor converts it back
to `Int64`, which only succeeds when the product is integer-valued.
"""
Base.:*(b::Simplestruct, a::Float64) = Simplestruct(a * b.a, a * b.b)

このSimplestructはフィールドaは実数、フィールドbは整数です。このSimplestructを引数とした関数を

"""
    ff(a::Simplestruct, b::Simplestruct)

Scalar-valued example function of two `Simplestruct`s.

It forms the field-wise product `a * b`, builds an intermediate `Simplestruct` from
`sqrt(a.a) + b.a + (a * b).b` and `a.b + b.b`, and returns `x + y + x * y`, where
`x` and `y` are the real and integer fields of that intermediate struct.
"""
function ff(a::Simplestruct, b::Simplestruct)
    ab = a * b
    real_part = sqrt(a.a) + b.a + ab.b
    int_part = a.b + b.b
    # Route the values through a struct on purpose: the example exercises
    # differentiation through user-defined types.
    mixed = Simplestruct(real_part, int_part)
    return mixed.a + mixed.b + mixed.a * mixed.b
end

と定義しておきます。出力は実数です。この関数をSimplestructで微分してみましょう。まず、structの値を微小変化させることで数値微分を定義しておきます。

"""
    calc_dfdA_dfdA(f, A::Simplestruct, B::Simplestruct, eta=1e-4)

Estimate the derivatives of `f(A, B)` with respect to the real fields `A.a` and `B.a`
by forward finite differences with step `eta`.

Returns the tuple `(dfdA, dfdB)`. The integer fields are left unperturbed.
"""
function calc_dfdA_dfdA(f, A::Simplestruct, B::Simplestruct, eta=1e-4)
    baseline = f(A, B)

    # Perturb only the real field of A.
    shifted_A = Simplestruct(A.a + eta, A.b)
    dfdA = (f(shifted_A, B) - baseline) / eta

    # Perturb only the real field of B.
    shifted_B = Simplestruct(B.a + eta, B.b)
    dfdB = (f(A, shifted_B) - baseline) / eta

    return dfdA, dfdB
end

そして、自動微分と比較します。自動微分は

"""
    main()

Demo driver: print the results of the `Simplestruct` arithmetic, evaluate `ff`,
and compare finite-difference derivatives with Enzyme's reverse-mode result.
"""
function main()
    A = Simplestruct(1.0, 2)
    B = Simplestruct(3.0, 4)
    println("A + B = ", A + B)
    println("A - B = ", A - B)
    println("A * B = ", A * B)
    println("ff(A, B) = ", ff(A, B))

    numeric_dA, numeric_dB = calc_dfdA_dfdA(ff, A, B)
    println("dfdA_numerical = ", numeric_dA)
    println("dfdB_numerical = ", numeric_dB)

    # Shadow structs receive the gradients. Their integer fields start at arbitrary
    # non-zero values so it is visible afterwards whether they were touched.
    shadow_A = Simplestruct(0.0, 10)
    println("initial dfA = ", shadow_A)
    shadow_B = Simplestruct(0.0, -4)
    println("initial dfB = ", shadow_B)
    Enzyme.autodiff(Reverse, ff, Duplicated(A, shadow_A), Duplicated(B, shadow_B))
    println("dfA = ", shadow_A)
    println("dfB = ", shadow_B)
    return nothing
end
main()

のように、autodiffで計算できます。ここで、dfAやdfBには微分された値が格納されます。
出力結果は

A + B = Simplestruct(4.0, 6)
A - B = Simplestruct(-2.0, -2)
A * B = Simplestruct(3.0, 8)
ff(A, B) = 90.0
dfdA_numerical = 3.499912504310032
dfdB_numerical = 6.999999999948159
initial dfA = Simplestruct(0.0, 10)
initial dfB = Simplestruct(0.0, -4)
dfA = Simplestruct(3.5, 10)
dfB = Simplestruct(7.0, -4)

となります。これを見ると、実数の部分はちゃんと微分されており、整数の場合はdfA,dfBは変更されていないことがわかります。つまり、微分されていません。
自動微分はgradientでも計算できますから、

 # Same derivative through the convenience API; the result printed below is a tuple
 # with one gradient struct per argument.
 c = Enzyme.gradient(Reverse, ff, A, B)
println("c = ", c)

でもできます。
この出力結果は

c = (Simplestruct(3.5, 2), Simplestruct(7.0, 4))

となります。見てわかるように、整数の部分はAとBのフィールドの値をそのまま引き継いでいます。

もう少し複雑な場合

複雑な場合についてもみていきましょう。
今度はStructを

"""
    Subfield{Nc}

Immutable wrapper around a `Float64` matrix `A`, tagged with the type parameter `Nc`.
"""
struct Subfield{Nc}
    A::Array{Float64,2}
end

"""
    Subfield(Nc)

Build a `Subfield{Nc}` holding an `Nc`-by-`Nc` matrix of zeros.
"""
Subfield(Nc) = Subfield{Nc}(zeros(Float64, Nc, Nc))

"""
    Subfield(Nc, A)

Wrap the matrix `A` in a `Subfield{Nc}`.
"""
Subfield(Nc, A) = Subfield{Nc}(A)


# Composite test type: a 3-D real array, a real scalar, a nested struct and an integer.
# The type parameter L is not referenced by any field type.
mutable struct Field{Nc,L}
    A::Array{Float64,3}  # 3-D real array
    value::Float64       # real scalar
    subA::Subfield{Nc}   # nested struct wrapping a real matrix
    ivalue::Int64        # integer scalar
end

として、structの中にstructが入っている形になっています。
このFieldを引数として、関数を

"""
    calc_f!(a, b)

Overwrite `a` in place using `b`: `a.A` becomes the array of `b + b'`, `a.value` is
incremented by `tr(a) + b.value + (b + b').value` (with `tr(a)` taken after `a.A` was
overwritten), and `a.ivalue` is incremented by `3 * b.ivalue`. Returns `nothing`.
"""
function calc_f!(a, b)
    symmetrized = b + b'
    a.A .= symmetrized.A
    increment = tr(a) + b.value + symmetrized.value
    a.value = a.value + increment
    a.ivalue = a.ivalue + 3 * b.ivalue
    return nothing
end

"""
    trf(a, b)

Scalar-valued test function. It first mutates `a` through `calc_f!(a, b)` and then
returns `tr(a * a) * tr(b) + b.value * (tr(b) + a.ivalue) + tr(b.subA.A * a.subA.A)`.
"""
function trf(a, b)
    calc_f!(a, b)
    trace_b = tr(b)
    quadratic_term = tr(a * a) * trace_b
    scalar_term = b.value * (trace_b + a.ivalue)
    nested_term = tr(b.subA.A * a.subA.A)
    return quadratic_term + scalar_term + nested_term
end

と定義しておきます。途中でフィールドの値を変えたりやりたい放題しているスカラー値関数です。これを自動微分します。

数値微分との比較をしたコードは

using Enzyme

"""
    Subfield{Nc}

Immutable wrapper around a `Float64` matrix `A`, tagged with the type parameter `Nc`.
"""
struct Subfield{Nc}
    A::Array{Float64,2}
end

"""
    Subfield(Nc)

Build a `Subfield{Nc}` holding an `Nc`-by-`Nc` matrix of zeros.
"""
Subfield(Nc) = Subfield{Nc}(zeros(Float64, Nc, Nc))

"""
    Subfield(Nc, A)

Wrap the matrix `A` in a `Subfield{Nc}`.
"""
Subfield(Nc, A) = Subfield{Nc}(A)


# Composite test type: a 3-D real array, a real scalar, a nested struct and an integer.
# The type parameter L is not referenced by any field type; the constructors in this
# file allocate A with size (Nc, Nc, L).
mutable struct Field{Nc,L}
    A::Array{Float64,3}  # 3-D real array
    value::Float64       # real scalar
    subA::Subfield{Nc}   # nested struct wrapping a real matrix
    ivalue::Int64        # integer scalar
end

"""
    Field(Nc, L)

Construct a `Field{Nc,L}` with every component zero: an `Nc`×`Nc`×`L` zero array,
`value == 0.0`, a zero `Subfield{Nc}` and `ivalue == 0`.
"""
function Field(Nc, L)
    return Field{Nc,L}(zeros(Float64, Nc, Nc, L), 0.0, Subfield(Nc), 0)
end

"""
    identity(Nc, L)

Construct a `Field{Nc,L}` whose `L` slices of `A` are identity matrices, with
`value == 1.0`, `ivalue == 1` and a uniformly random `Nc`×`Nc` sub-field.

NOTE(review): this shares its name with `Base.identity` but is defined without the
`Base.` qualifier.
"""
function identity(Nc, L)
    A = zeros(Float64, Nc, Nc, L)
    for site in 1:L, ic in 1:Nc
        A[ic, ic, site] = 1
    end
    sub = Subfield(Nc, rand(Float64, Nc, Nc))
    return Field{Nc,L}(A, 1.0, sub, 1)
end

"""
    random_field(Nc, L, name="")

Construct a `Field{Nc,L}` with uniformly random `A`, `value` and sub-field matrix,
and an `ivalue` drawn from `[-1, 0, 1]`. The `name` argument is accepted but unused.
"""
function random_field(Nc, L, name="")
    # The order of the rand calls is kept so seeded runs reproduce the same field.
    data = rand(Float64, Nc, Nc, L)
    scalar = rand()
    sub = Subfield(Nc, rand(Float64, Nc, Nc))
    int_scalar = rand([-1, 0, 1])
    return Field{Nc,L}(data, scalar, sub, int_scalar)
end


"""
    copy(a::Field{Nc,L})

Return an independent copy of `a`: the array `A` and the sub-field matrix `subA.A`
are duplicated, so mutating the copy never affects `a`; `value` and `ivalue` are
carried over.

The sub-field is copied from `a.subA.A`. (An earlier version filled it from the first
slice of `a.A`, which silently produced a copy with a wrong `subA`.)
"""
function Base.copy(a::Field{Nc,L}) where {Nc,L}
    sub = Subfield(Nc, copy(a.subA.A))
    return Field{Nc,L}(copy(a.A), a.value, sub, a.ivalue)
end

"""
    zero(a::Field{Nc,L})

Return a freshly allocated all-zero `Field` with the same `Nc` and `L` as `a`.
"""
Base.zero(a::Field{Nc,L}) where {Nc,L} = Field(Nc, L)


"""
    +(a::Field{Nc,L}, b::Field{Nc,L})

Return a new `Field` whose array slices, `value`, sub-field matrix and `ivalue` are
the component-wise sums of those of `a` and `b`.
"""
function Base.:+(a::Field{Nc,L}, b::Field{Nc,L}) where {Nc,L}
    out = zero(a)
    for site in 1:L
        lhs = view(a.A, :, :, site)
        rhs = view(b.A, :, :, site)
        out.A[:, :, site] = lhs .+ rhs
    end
    out.subA.A .= a.subA.A .+ b.subA.A
    out.value = a.value + b.value
    out.ivalue = a.ivalue + b.ivalue
    return out
end

"""
    -(a::Field{Nc,L}, b::Field{Nc,L})

Return a new `Field` whose array slices, `value`, sub-field matrix and `ivalue` are
the component-wise differences `a - b`.
"""
function Base.:-(a::Field{Nc,L}, b::Field{Nc,L}) where {Nc,L}
    out = zero(a)
    for site in 1:L
        lhs = view(a.A, :, :, site)
        rhs = view(b.A, :, :, site)
        out.A[:, :, site] = lhs .- rhs
    end
    out.subA.A .= a.subA.A .- b.subA.A
    out.value = a.value - b.value
    out.ivalue = a.ivalue - b.ivalue
    return out
end


"""
    *(a::Field{Nc,L}, b::Field{Nc,L})

Return a new `Field` in which each slice of `A` and the sub-field matrix are matrix
products of the corresponding parts of `a` and `b`, while `value` and `ivalue` are
ordinary scalar products.
"""
function Base.:*(a::Field{Nc,L}, b::Field{Nc,L}) where {Nc,L}
    out = zero(a)
    @views for site in 1:L
        # Slice-wise matrix product written directly into the output array.
        mul!(out.A[:, :, site], a.A[:, :, site], b.A[:, :, site])
    end
    out.subA.A .= a.subA.A * b.subA.A
    out.value = a.value * b.value
    out.ivalue = a.ivalue * b.ivalue
    return out
end

"""
    *(a::Number, b::Field{Nc,L})

Return a new `Field` with every component of `b` scaled by the scalar `a`.

The scaled `ivalue` is stored in an `Int64` field, so the assignment only succeeds
when `a * b.ivalue` is integer-valued.
"""
function Base.:*(a::T, b::Field{Nc,L}) where {Nc,L,T<:Number}
    out = zero(b)
    for site in 1:L
        out.A[:, :, site] = a * view(b.A, :, :, site)
    end
    out.subA.A .= a * b.subA.A
    out.value = a * b.value
    out.ivalue = a * b.ivalue
    return out
end

"""
    *(b::Field{Nc,L}, a::Number)

Right-hand scalar multiplication: return a new `Field` with every component of `b`
scaled by `a`.

The scaled `ivalue` is stored in an `Int64` field, so the assignment only succeeds
when `a * b.ivalue` is integer-valued.
"""
function Base.:*(b::Field{Nc,L}, a::T) where {Nc,L,T<:Number}
    out = zero(b)
    for site in 1:L
        out.A[:, :, site] = a * view(b.A, :, :, site)
    end
    out.subA.A .= a * b.subA.A
    out.value = a * b.value
    out.ivalue = a * b.ivalue
    return out
end

"""
    display(a::Field{Nc,L})

Print `a` to standard output: each of the `L` slices of `A` with its index, then
`value`, the sub-field matrix and `ivalue`.
"""
function Base.display(a::Field{Nc,L}) where {Nc,L}
    foreach(1:L) do i
        println("i = $i")
        display(a.A[:, :, i])
        println("\t")
    end
    println("value = ", a.value)
    display(a.subA.A)
    println("ivalue = ", a.ivalue)
    return nothing
end

"""
    adjoint(a::Field{Nc,L})

Return a new `Field` in which every slice of `A` and the sub-field matrix are
transposed; `value` and `ivalue` are carried over unchanged.
"""
function Base.adjoint(a::Field{Nc,L}) where {Nc,L}
    out = zero(a)
    for site in 1:L, row in 1:Nc, col in 1:Nc
        out.A[row, col, site] = a.A[col, row, site]
    end
    for row in 1:Nc, col in 1:Nc
        out.subA.A[row, col] = a.subA.A[col, row]
    end
    out.value = a.value
    out.ivalue = a.ivalue
    return out
end

using LinearAlgebra
"""
    tr(a::Field{Nc,L})

Return the sum of the diagonal entries of all `L` slices of `a.A`. The sub-field,
`value` and `ivalue` do not contribute.
"""
function LinearAlgebra.tr(a::Field{Nc,L}) where {Nc,L}
    total = 0.0
    for site in 1:L, ic in 1:Nc
        total += a.A[ic, ic, site]
    end
    return total
end

"""
    calc_f!(a, b)

Overwrite `a` in place using `b`: `a.A` becomes the array of `b + b'`, `a.value` is
incremented by `tr(a) + b.value + (b + b').value` (with `tr(a)` taken after `a.A` was
overwritten), and `a.ivalue` is incremented by `3 * b.ivalue`. Returns `nothing`.
"""
function calc_f!(a, b)
    symmetrized = b + b'
    a.A .= symmetrized.A
    increment = tr(a) + b.value + symmetrized.value
    a.value = a.value + increment
    a.ivalue = a.ivalue + 3 * b.ivalue
    return nothing
end

"""
    trf(a, b)

Scalar-valued test function. It first mutates `a` through `calc_f!(a, b)` and then
returns `tr(a * a) * tr(b) + b.value * (tr(b) + a.ivalue) + tr(b.subA.A * a.subA.A)`.
"""
function trf(a, b)
    calc_f!(a, b)
    trace_b = tr(b)
    quadratic_term = tr(a * a) * trace_b
    scalar_term = b.value * (trace_b + a.ivalue)
    nested_term = tr(b.subA.A * a.subA.A)
    return quadratic_term + scalar_term + nested_term
end

"""
    calc_f(x, d)

Return `tr(x * x + d)`.
"""
calc_f(x, d) = tr(x * x + d)




"""
    test()

Compare a forward finite-difference gradient of `trf(a, b)` with respect to `b`
against Enzyme's reverse-mode gradient. Three `Field`s are displayed in order: the
numerical gradient, the Enzyme gradient, and their difference.

`trf` mutates its first argument (it accumulates into `a.value` and `a.ivalue`), so
every evaluation, including the Enzyme call, runs on a fresh `deepcopy(a)`. Without
that, the baseline and the perturbed values are computed from different states of `a`
whenever `b.ivalue != 0`, and the finite differences are meaningless.
"""
function test()
    Nc = 3
    L = 4
    a = random_field(Nc, L)
    b = random_field(Nc, L)

    ff(a, b) = trf(a, b)
    # Evaluate on a pristine copy of `a` so repeated calls are comparable.
    evaluate(bfield) = ff(deepcopy(a), bfield)

    dfdb = zero(a)
    fa = evaluate(b)
    println(fa)
    delta = 1e-9

    # Derivative with respect to each entry of b.A.
    for i in 1:L, ic in 1:Nc, jc in 1:Nc
        b_p = deepcopy(b)
        b_p.A[ic, jc, i] += delta
        dfdb.A[ic, jc, i] = (evaluate(b_p) - fa) / delta
    end

    # Derivative with respect to the real scalar field.
    b_p = deepcopy(b)
    b_p.value += delta
    fad = evaluate(b_p)
    println(fad)
    dfdb.value = (fad - fa) / delta

    # Derivative with respect to each entry of the nested sub-field matrix.
    for ic in 1:Nc, jc in 1:Nc
        b_q = deepcopy(b)
        b_q.subA.A[ic, jc] += delta
        dfdb.subA.A[ic, jc] = (evaluate(b_q) - fa) / delta
    end

    display(dfdb)
    ba = zero(a)
    bb = zero(a)

    # Differentiate from the same initial state of `a` as the finite differences.
    Enzyme.autodiff(Reverse, ff, Duplicated(deepcopy(a), ba), Duplicated(b, bb))
    display(bb)
    display(bb - dfdb)

    return
end
test()

となります。

出力結果は、

287.0472017667113
287.04720177260765
i = 1
3×3 Matrix{Float64}:
 71.7799  33.7989  32.0543
 33.7989  73.7601  36.8872
 32.0543  36.8872  64.7643
	
i = 2
3×3 Matrix{Float64}:
 66.9696  28.3479  22.2473
 28.3479  87.6349  21.5292
 22.2473  21.5292  63.2297
	
i = 3
3×3 Matrix{Float64}:
 62.7962  35.213   22.931
 35.213   56.9894  23.5439
 22.931   23.5439  91.0886
	
i = 4
3×3 Matrix{Float64}:
 90.9276  23.5349  22.0592
 23.5349  71.3294  25.4303
 22.0592  25.4303  56.9655
	
value = 5.896367838431615
3×3 Matrix{Float64}:
 0.262787  0.782734  0.954515
 0.212708  0.298542  0.382784
 0.514717  0.626756  0.656371
ivalue = 0
i = 1
3×3 Matrix{Float64}:
 71.78    33.7989  32.0543
 33.7989  73.7602  36.8872
 32.0543  36.8872  64.7643
	
i = 2
3×3 Matrix{Float64}:
 66.9696  28.3478  22.2473
 28.3478  87.6349  21.5291
 22.2473  21.5291  63.2297
	
i = 3
3×3 Matrix{Float64}:
 62.7963  35.213   22.931
 35.213   56.9894  23.5439
 22.931   23.5439  91.0886
	
i = 4
3×3 Matrix{Float64}:
 90.9277  23.5349  22.0592
 23.5349  71.3294  25.4303
 22.0592  25.4303  56.9655
	
value = 5.896376246674556
3×3 Matrix{Float64}:
 0.26281   0.782738  0.954528
 0.212709  0.298533  0.382799
 0.514716  0.626763  0.656369
ivalue = 0
i = 1
3×3 Matrix{Float64}:
  4.05817e-5  -2.26194e-5  -3.74406e-7
 -2.26194e-5   4.73619e-5  -4.32907e-5
 -3.74406e-7  -4.32907e-5   1.74079e-5
	
i = 2
3×3 Matrix{Float64}:
  4.84749e-5  -3.57308e-5   6.58313e-6
 -3.57308e-5   8.53416e-6  -3.0923e-5
  6.58313e-6  -3.0923e-5   -4.80622e-6
	
i = 3
3×3 Matrix{Float64}:
  3.08882e-5  3.30059e-5  -2.79435e-5
  3.30059e-5  3.55657e-5   3.44012e-5
 -2.79435e-5  3.44012e-5  -1.1738e-5
	
i = 4
3×3 Matrix{Float64}:
 2.93962e-5  4.67822e-5  9.08425e-6
 4.67822e-5  9.91826e-6  2.66239e-5
 9.08425e-6  2.66239e-5  2.18415e-5
	
value = 8.408242941015942e-6
3×3 Matrix{Float64}:
  2.24301e-5   3.84649e-6   1.36945e-5
  7.05173e-7  -8.70779e-6   1.49238e-5
 -1.04753e-6   7.29251e-6  -1.51115e-6
ivalue = 0

となります。
出力は順に、数値微分の結果、自動微分の結果、両者の差です。2番目の出力(自動微分の結果)が、型がFieldである「スカラー値関数のField微分」になっています。それぞれのフィールドにそれぞれの微分が入っています。実数となっているフィールドには微分の値が入っていることがわかります。整数のフィールドは0です。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?