Juliaでの新しい自動微分パッケージEnzyme.jlについてです。以前新世代自動微分パッケージEnzymeをJuliaで試すで紹介しました。
このパッケージの面白いところは、自分で定義した構造体を使っても微分ができることです。公式ページのWhat types are differentiable?によると、Structの中の実数は微分できます。そして、Structの中にStructを定義するような入れ子構造になっていても、フィールドが実数であれば追跡微分してくれます。それでは、Structのフィールドに実数以外が入っている場合はどうなるのでしょう? それを確認してみます。
環境
- Julia 1.11.5 (2025-04-14)
- Enzyme v0.13.35
簡単な例
以下のような簡単なStructを定義します。
using Enzyme
mutable struct Simplestruct
a::Float64
b::Int64
end
function Base.:+(a::Simplestruct, b::Simplestruct)
ca = a.a + b.a
cb = a.b + b.b
return Simplestruct(ca, cb)
end
function Base.:-(a::Simplestruct, b::Simplestruct)
ca = a.a - b.a
cb = a.b - b.b
return Simplestruct(ca, cb)
end
function Base.:*(a::Simplestruct, b::Simplestruct)
ca = a.a * b.a
cb = a.b * b.b
return Simplestruct(ca, cb)
end
function Base.:*(a::Float64, b::Simplestruct)
ca = a * b.a
cb = a * b.b
return Simplestruct(ca, cb)
end
function Base.:*(b::Simplestruct, a::Float64)
ca = a * b.a
cb = a * b.b
return Simplestruct(ca, cb)
end
このSimplestructはフィールドaは実数、フィールドbは整数です。このSimplestructを引数とした関数を
function ff(a::Simplestruct, b::Simplestruct)
d = a * b
ca = sqrt(a.a) + b.a + d.b
cb = a.b + b.b
c = Simplestruct(ca, cb)
return c.a + c.b + c.a * c.b
end
と定義しておきます。出力は実数です。この関数をSimplestructで微分してみましょう。まず、structの値を微小変化させることで数値微分を定義しておきます。
function calc_dfdA_dfdA(f, A::Simplestruct, B::Simplestruct, eta=1e-4)
fvalue = f(A, B)
Ap = Simplestruct(A.a + eta, A.b)
fpvalue = f(Ap, B)
dfdA = (fpvalue - fvalue) / eta
Bp = Simplestruct(B.a + eta, B.b)
fpvalue = f(A, Bp)
dfdB = (fpvalue - fvalue) / eta
return dfdA, dfdB
end
そして、自動微分と比較します。自動微分は
function main()
A = Simplestruct(1.0, 2)
B = Simplestruct(3.0, 4)
C = A + B
println("A + B = ", C)
D = A - B
println("A - B = ", D)
E = A * B
println("A * B = ", E)
fvalue = ff(A, B)
println("ff(A, B) = ", fvalue)
dfdA_numerical, dfdB_numerical = calc_dfdA_dfdA(ff, A, B)
println("dfdA_numerical = ", dfdA_numerical)
println("dfdB_numerical = ", dfdB_numerical)
dfA = Simplestruct(0.0, 10)
println("initial dfA = ", dfA)
dfB = Simplestruct(0.0, -4)
println("initial dfB = ", dfB)
Enzyme.autodiff(Reverse, ff, Duplicated(A, dfA), Duplicated(B, dfB))
println("dfA = ", dfA)
println("dfB = ", dfB)
end
main()
のように、autodiffで計算できます。ここで、dfAやdfBは微分された値が格納されます。
出力結果は
A + B = Simplestruct(4.0, 6)
A - B = Simplestruct(-2.0, -2)
A * B = Simplestruct(3.0, 8)
ff(A, B) = 64.89949493661167
dfdA_numerical = 20.999999999986585
dfdB_numerical = 6.999999999948159
initial dfA = Simplestruct(0.0, 10)
initial dfB = Simplestruct(0.0, -4)
dfA = Simplestruct(21.0, 10)
dfB = Simplestruct(7.0, -4)
c = (Simplestruct(21.0, 2), Simplestruct(7.0, 4))
yuki@YukiMacBook-Pro-M2-2 qiita % julia enzymestructsimple.jl
A + B = Simplestruct(4.0, 6)
A - B = Simplestruct(-2.0, -2)
A * B = Simplestruct(3.0, 8)
ff(A, B) = 90.0
dfdA_numerical = 3.499912504310032
dfdB_numerical = 6.999999999948159
initial dfA = Simplestruct(0.0, 10)
initial dfB = Simplestruct(0.0, -4)
dfA = Simplestruct(3.5, 10)
dfB = Simplestruct(7.0, -4)
となります。これを見ると、実数の部分はちゃんと微分されており、整数の場合はdfA,dfBは変更されていないことがわかります。つまり、微分されていません。
自動微分はgradientでも計算できますから、
c = Enzyme.gradient(Reverse, ff, A, B)
println("c = ", c)
でもできます。
この出力結果は
c = (Simplestruct(3.5, 2), Simplestruct(7.0, 4))
となります。見てわかるように、整数の部分はAとBのフィールドの値をそのまま引き継いでいます。
もう少し複雑な場合
複雑な場合についてもみていきましょう。
今度はStructを
struct Subfield{Nc}
A::Array{Float64,2}
end
function Subfield(Nc)
A = zeros(Float64, Nc, Nc)
return Subfield{Nc}(A)
end
function Subfield(Nc, A)
return Subfield{Nc}(A)
end
mutable struct Field{Nc,L}
A::Array{Float64,3}
value::Float64
subA::Subfield{Nc}
ivalue::Int64
end
として、structの中にstructが入っている形になっています。
このFieldを引数として、関数を
function calc_f!(a, b)
c = b + b'
a.A .= c.A
a.value += tr(a) + b.value + c.value
a.ivalue += 3 * b.ivalue
return nothing
end
function trf(a, b)
calc_f!(a, b)
return tr(a * a) * tr(b) + b.value * (tr(b) + a.ivalue) + tr(b.subA.A * a.subA.A)
end
と定義しておきます。途中でフィールドの値を変えたりやりたい放題しているスカラー値関数です。これを自動微分します。
数値微分との比較をしたコードは
using Enzyme
struct Subfield{Nc}
A::Array{Float64,2}
end
function Subfield(Nc)
A = zeros(Float64, Nc, Nc)
return Subfield{Nc}(A)
end
function Subfield(Nc, A)
return Subfield{Nc}(A)
end
mutable struct Field{Nc,L}
A::Array{Float64,3}
value::Float64
subA::Subfield{Nc}
ivalue::Int64
end
function Field(Nc, L)
A = zeros(Float64, Nc, Nc, L)
value = 0.0
subA = Subfield(Nc)
ivalue = 0.0
return Field{Nc,L}(A, value, subA, ivalue)
end
function identity(Nc, L)
A = zeros(Float64, Nc, Nc, L)
value = 1.0
for i = 1:L
for ic = 1:Nc
A[ic, ic, i] = 1
end
end
subA = Subfield(Nc, rand(Float64, Nc, Nc))
ivalue = 1.0
return Field{Nc,L}(A, value, subA, ivalue)
end
function random_field(Nc, L, name="")
A = rand(Float64, Nc, Nc, L)
value = rand()
subB = Subfield(Nc, rand(Float64, Nc, Nc))
ivalue = rand([-1, 0, 1])
return Field{Nc,L}(A, value, subB, ivalue)
end
function Base.copy(a::Field{Nc,L}) where {Nc,L}
A = zeros(Float64, Nc, Nc, L)
for i = 1:L
for ic = 1:Nc
for jc = 1:Nc
A[jc, ic, i] = a.A[jc, ic, i]
end
end
end
B = zeros(Float64, Nc, Nc)
B .= a.A[:, :, 1]
subA = Subfield(Nc, B)
return Field{Nc,L}(A, a.value, subA, a.ivalue)
end
function Base.zero(a::Field{Nc,L}) where {Nc,L}
return Field(Nc, L)
end
function Base.:+(a::Field{Nc,L}, b::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i = 1:L
c.A[:, :, i] = view(a.A, :, :, i) .+ view(b.A, :, :, i)
end
c.value = a.value + b.value
c.subA.A .= a.subA.A .+ b.subA.A
c.ivalue = a.ivalue + b.ivalue
return c
end
function Base.:-(a::Field{Nc,L}, b::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i = 1:L
c.A[:, :, i] = view(a.A, :, :, i) .- view(b.A, :, :, i)
end
c.value = a.value - b.value
c.subA.A .= a.subA.A .- b.subA.A
c.ivalue = a.ivalue - b.ivalue
return c
end
function Base.:*(a::Field{Nc,L}, b::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i = 1:L
mul!(view(c.A, :, :, i), view(a.A, :, :, i), view(b.A, :, :, i))
end
c.value = a.value * b.value
c.subA.A .= a.subA.A * b.subA.A
c.ivalue = a.ivalue * b.ivalue
return c
end
function Base.:*(a::T, b::Field{Nc,L}) where {Nc,L,T<:Number}
c = zero(b)
for i = 1:L
c.A[:, :, i] = a * view(b.A, :, :, i)
end
c.value = a * b.value
c.subA.A .= a * b.subA.A
c.ivalue = a * b.ivalue
return c
end
function Base.:*(b::Field{Nc,L}, a::T) where {Nc,L,T<:Number}
c = zero(b)
for i = 1:L
c.A[:, :, i] = a * view(b.A, :, :, i)
end
c.value = a * b.value
c.subA.A .= a * b.subA.A
c.ivalue = a * b.ivalue
return c
end
function Base.display(a::Field{Nc,L}) where {Nc,L}
for i = 1:L
println("i = $i")
display(a.A[:, :, i])
println("\t")
end
println("value = ", a.value)
display(a.subA.A)
println("ivalue = ", a.ivalue)
end
function Base.adjoint(a::Field{Nc,L}) where {Nc,L}
c = zero(a)
for i = 1:L
for ic = 1:Nc
for jc = 1:Nc
c.A[ic, jc, i] = a.A[jc, ic, i]
end
end
end
c.value = a.value
for ic = 1:Nc
for jc = 1:Nc
c.subA.A[ic, jc] = a.subA.A[jc, ic]
end
end
c.ivalue = a.ivalue
return c
end
using LinearAlgebra
function LinearAlgebra.tr(a::Field{Nc,L}) where {Nc,L}
c = 0.0
for i = 1:L
for ic = 1:Nc
c += a.A[ic, ic, i]
end
end
return c
end
function calc_f!(a, b)
c = b + b'
a.A .= c.A
a.value += tr(a) + b.value + c.value
a.ivalue += 3 * b.ivalue
return nothing
end
function trf(a, b)
calc_f!(a, b)
return tr(a * a) * tr(b) + b.value * (tr(b) + a.ivalue) + tr(b.subA.A * a.subA.A)
end
function calc_f(x, d)
c = tr(x * x + d)
return c
end
function test()
Nc = 3
L = 4
U0 = identity(Nc, L)
a = random_field(Nc, L)
b = random_field(Nc, L)
dfdb = zero(a)
# ff(a) = tr(a * a)
ff(a, b) = trf(a, b)
fa = ff(a, b)#calc_f(a, U0)
println(fa)
delta = 1e-9
for i = 1:L
for ic = 1:Nc
for jc = 1:Nc
b_p = deepcopy(b)
b_p.A[ic, jc, i] += delta
fad = ff(a, b_p)#calc_f(a_p, U0)
dfdb.A[ic, jc, i] = (fad - fa) / delta
end
end
end
b_p = deepcopy(b)
b_p.value += delta
fad = ff(a, b_p)#calc_f(a_p, U0)
println(fad)
dfdb.value = (fad - fa) / delta
for ic = 1:Nc
for jc = 1:Nc
b_p = deepcopy(b)
b_p.subA.A[ic, jc] += delta
fad = ff(a, b_p)#calc_f(a_p, U0)
dfdb.subA.A[ic, jc] = (fad - fa) / delta
end
end
display(dfdb)
ba = zero(a)
bb = zero(a)
Enzyme.autodiff(Reverse, ff, Duplicated(a, ba), Duplicated(b, bb))
#Enzyme.autodiff(Reverse, ff, Duplicated(a, ba))
display(bb)
display(bb - dfdb)
return
end
test()
となります。
出力結果は、
367.49318813675427
367.49318813675427
i = 1
3×3 Matrix{Float64}:
55.7968 14.177 32.7866
14.177 73.0769 32.959
32.7867 32.959 79.7097
i = 2
3×3 Matrix{Float64}:
107.883 39.0587 33.0684
39.0587 100.565 30.5101
33.0684 30.5101 99.404
i = 3
3×3 Matrix{Float64}:
70.9697 56.4435 28.2004
56.4435 105.833 15.1688
28.2004 15.1688 104.882
i = 4
3×3 Matrix{Float64}:
81.0351 31.0974 22.0215
31.0974 92.8779 13.3355
22.0215 13.3355 62.0721
value = 0.0
3×3 Matrix{Float64}:
0.181956 0.0532623 0.67331
0.0179057 0.428429 0.625562
0.110049 0.0352429 0.782279
ivalue = 0
i = 1
3×3 Matrix{Float64}:
55.7967 14.177 32.7866
14.177 73.0769 32.9589
32.7866 32.9589 79.7096
i = 2
3×3 Matrix{Float64}:
107.883 39.0587 33.0684
39.0587 100.565 30.5101
33.0684 30.5101 99.404
i = 3
3×3 Matrix{Float64}:
70.9697 56.4435 28.2003
56.4435 105.833 15.1688
28.2003 15.1688 104.882
i = 4
3×3 Matrix{Float64}:
81.035 31.0974 22.0214
31.0974 92.8779 13.3354
22.0214 13.3354 62.0721
value = 0.0
3×3 Matrix{Float64}:
0.181979 0.0532863 0.673329
0.0179384 0.428452 0.625592
0.110059 0.0352599 0.782314
ivalue = 0
i = 1
3×3 Matrix{Float64}:
-2.17127e-5 -8.46588e-6 7.72882e-6
-8.46588e-6 -1.33173e-5 -2.74174e-5
-4.91146e-5 -2.74174e-5 -1.01197e-5
i = 2
3×3 Matrix{Float64}:
-3.76411e-5 -2.36041e-6 7.1748e-6
-2.36041e-6 -1.88383e-5 3.59186e-5
7.1748e-6 3.59186e-5 -2.26093e-5
i = 3
3×3 Matrix{Float64}:
1.17163e-5 6.54263e-6 -2.97498e-5
6.54263e-6 -1.39499e-5 -4.83976e-5
-2.97498e-5 -4.83976e-5 -9.02612e-5
i = 4
3×3 Matrix{Float64}:
-5.65484e-5 -3.2358e-5 -3.27307e-5
-3.2358e-5 -2.50569e-5 -7.53051e-5
-3.27307e-5 -7.53051e-5 1.48167e-5
value = 0.0
3×3 Matrix{Float64}:
2.33452e-5 2.40411e-5 1.87145e-5
3.27268e-5 2.2835e-5 3.01355e-5
1.03255e-5 1.69957e-5 3.47077e-5
ivalue = 0
yuki@YukiMacBook-Pro-M2-2 qiita % julia enzymestruct.jl
287.0472017667113
287.04720177260765
i = 1
3×3 Matrix{Float64}:
71.7799 33.7989 32.0543
33.7989 73.7601 36.8872
32.0543 36.8872 64.7643
i = 2
3×3 Matrix{Float64}:
66.9696 28.3479 22.2473
28.3479 87.6349 21.5292
22.2473 21.5292 63.2297
i = 3
3×3 Matrix{Float64}:
62.7962 35.213 22.931
35.213 56.9894 23.5439
22.931 23.5439 91.0886
i = 4
3×3 Matrix{Float64}:
90.9276 23.5349 22.0592
23.5349 71.3294 25.4303
22.0592 25.4303 56.9655
value = 5.896367838431615
3×3 Matrix{Float64}:
0.262787 0.782734 0.954515
0.212708 0.298542 0.382784
0.514717 0.626756 0.656371
ivalue = 0
i = 1
3×3 Matrix{Float64}:
71.78 33.7989 32.0543
33.7989 73.7602 36.8872
32.0543 36.8872 64.7643
i = 2
3×3 Matrix{Float64}:
66.9696 28.3478 22.2473
28.3478 87.6349 21.5291
22.2473 21.5291 63.2297
i = 3
3×3 Matrix{Float64}:
62.7963 35.213 22.931
35.213 56.9894 23.5439
22.931 23.5439 91.0886
i = 4
3×3 Matrix{Float64}:
90.9277 23.5349 22.0592
23.5349 71.3294 25.4303
22.0592 25.4303 56.9655
value = 5.896376246674556
3×3 Matrix{Float64}:
0.26281 0.782738 0.954528
0.212709 0.298533 0.382799
0.514716 0.626763 0.656369
ivalue = 0
i = 1
3×3 Matrix{Float64}:
4.05817e-5 -2.26194e-5 -3.74406e-7
-2.26194e-5 4.73619e-5 -4.32907e-5
-3.74406e-7 -4.32907e-5 1.74079e-5
i = 2
3×3 Matrix{Float64}:
4.84749e-5 -3.57308e-5 6.58313e-6
-3.57308e-5 8.53416e-6 -3.0923e-5
6.58313e-6 -3.0923e-5 -4.80622e-6
i = 3
3×3 Matrix{Float64}:
3.08882e-5 3.30059e-5 -2.79435e-5
3.30059e-5 3.55657e-5 3.44012e-5
-2.79435e-5 3.44012e-5 -1.1738e-5
i = 4
3×3 Matrix{Float64}:
2.93962e-5 4.67822e-5 9.08425e-6
4.67822e-5 9.91826e-6 2.66239e-5
9.08425e-6 2.66239e-5 2.18415e-5
value = 8.408242941015942e-6
3×3 Matrix{Float64}:
2.24301e-5 3.84649e-6 1.36945e-5
7.05173e-7 -8.70779e-6 1.49238e-5
-1.04753e-6 7.29251e-6 -1.51115e-6
ivalue = 0
となります。
最後の出力が型がFieldの「スカラー値関数のField微分」となっています。それぞれのフィールドにそれぞれの微分が入っています。実数となっているフィールドは微分の値が入っていることがわかります。整数のフィールドは0です。