LoginSignup
4
1

More than 1 year has passed since last update.

Juliaでの自動微分を使って行列を引数にした複雑な関数を行列で微分してみる

Last updated at Posted at 2022-09-07

Juliaでの自動微分を使って行列で微分してみる
の続編です。さらに複雑な関数を微分できるようにしたいので、それをやってみたいと思います。

計算したい量

前回は、

f(\{ U \}) = \sum_{k=1}^{L}{\rm tr} (U_k U_k + I)

という量を計算し、その微分:

\frac{\partial f(\{ U \})}{\partial [U_{k}]_{ij}} 

を自動微分で計算しました。ここで、$U_k$は$k$番目の格子点の上に定義された$n \times n$の行列で、コード上では

struct Field{Nc,L}
    A::Array{Float64,3}
end

と定義されています。
次は、

f(\{ U \}) = \sum_{k=1}^{L}{\rm tr} (U_k U_{k+1} U_{k+2})

を計算し、その行列微分を自動で計算してみましょう。なお、$U_{L+1} = U_1$及び$U_{L+2} = U_2$とします。

シフトした量の定義

$U_{k+1}$という行列をどうやって扱うか、ですが、これはそれ専用の型を定義してしまいましょう。

struct Shifted_Field{Nc,L,shift}
    F::Field{Nc,L}
    function Shifted_Field(F::Field{Nc,L},shift) where {Nc,L}
        shiftedF = apply_shift(F,shift)
        return new{Nc,L,shift}(shiftedF)
    end
end

function get_shiftedi(i,L,shift)
    shiftedi = i + shift
    while shiftedi < 1 || shiftedi > L
        shiftedi += ifelse(shiftedi > L,-L,0)
        shiftedi += ifelse(shiftedi < 1,L,0)
    end
    return shiftedi
end

function apply_shift(F::Field{Nc,L},shift) where {Nc,L}
    Fout = zero(F)
    for i=1:L
        shiftedi = get_shiftedi(i,L,shift)
        for ic=1:Nc
            for jc=1:Nc
                Fout.A[jc,ic,i] = F.A[jc,ic,shiftedi]
            end
        end
    end
    return F
end

function shift_Field(F::Field{Nc,L},shift) where {Nc,L}
    return Shifted_Field(F,shift)
end

function shift_Field(a::Shifted_Field{Nc,L,shift0},shift) where {Nc,L,shift0}
    return Shifted_Field(a.F,shift0 + shift)
end

これで、Fshift = shift_Field(F,shift)とすることで、格子点がshiftだけずれたものを得ることができます。あとは、このShifted_Field型同士の積や、Shifted_Field型とField型の積を定義:

function Base.:*(a::Shifted_Field{Nc,L,shift},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift}
    c = zero(a.F)
    for i=1:L
        mul!(view(c.A,:,:,i),view(a.F.A,:,:,i),view(b.F.A,:,:,i))
    end
    return c
end

function Base.:*(a::T,b::Shifted_Field{Nc,L,shift}) where {Nc,L,T <: Number,shift}
    c = zero(b.F)
    for i=1:L
        c.A[:,:,i] = a*view(b.F.A,:,:,i)
    end
    return c
end

function Base.:*(b::Shifted_Field{Nc,L,shift},a::T) where {Nc,L,T <: Number,shift}
    c = zero(b.F)
    for i=1:L
        c.A[:,:,i] = a*view(b.F.A,:,:,i)
    end
    return c
end

function Base.:+(a::Shifted_Field{Nc,L,shift1},b::Shifted_Field{Nc,L,shift2}) where {Nc,L,shift1,shift2}
    c = zero(a.F)
    for i=1:L
        c.A[:,:,i] = view(a.F.A,:,:,i) .+ view(b.F.A,:,:,i)
    end
    return c
end

function Base.:-(a::Shifted_Field{Nc,L,shift},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift}
    c = zero(a.F)
    for i=1:L
        c.A[:,:,i] = view(a.F.A,:,:,i) .- view(b.F.A,:,:,i)
    end
    return c
end

function Base.:*(a::Shifted_Field{Nc,L,shift},b::Field{Nc,L}) where {Nc,L,shift}
    c = zero(a.F)
    for i=1:L
        mul!(view(c.A,:,:,i),view(a.F.A,:,:,i),view(b.A,:,:,i))
    end
    return c
end

function Base.:*(a::Field{Nc,L},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift}
    c = zero(a)
    for i=1:L
        mul!(view(c.A,:,:,i),view(a.A,:,:,i),view(b.F.A,:,:,i))
    end
    return c
end

function Base.:+(a::Shifted_Field{Nc,L,shift},b::Field{Nc,L}) where {Nc,L,shift}
    c = zero(a.F)
    for i=1:L
        c.A[:,:,i] = view(a.F.A,:,:,i) .+ view(b.A,:,:,i)
    end
    return c
end

function Base.:+(a::Field{Nc,L},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift}
    c = zero(a)
    for i=1:L
        c.A[:,:,i] = view(a.A,:,:,i) .+ view(b.F.A,:,:,i)
    end
    return c
end

function Base.:-(a::Shifted_Field{Nc,L,shift},b::Field{Nc,L}) where {Nc,L,shift}
    c = zero(a.F)
    for i=1:L
        c.A[:,:,i] = view(a.F.A,:,:,i) .- view(b.A,:,:,i)
    end
    return c
end

function Base.:-(a::Field{Nc,L},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift}
    c = zero(a)
    for i=1:L
        c.A[:,:,i] = view(a.A,:,:,i) .- view(b.F.A,:,:,i)
    end
    return c
end

function Base.adjoint(a::Shifted_Field{Nc,L}) where {Nc,L,shift}
    c = zero(a.F)
    for i=1:L
        for ic=1:Nc
            for jc=1:Nc
                c.A[ic,jc,i] = a.F.A[jc,ic,i]
            end
        end
    end
    return c
end

しておけば、

function test2()
    Nc = 2
    L = 4
    a = random_field(Nc,L)
    ashift = shift_Field(a,1)
    println(tr(a*ashift*a))
end
test2()

のようにすれば、

\sum_{k=1} {\rm tr} (U_k U_{k+1} U_k)

という量でも簡単に計算できます。ここで、Juliaでの自動微分を使って行列で微分してみる
で定義したものはすでに定義してあるとしています。

rruleの定義

連鎖律とpullback

自動微分をZygoteのやってもらうために、rruleを定義しておきます。まず、関数$f$の$[U_k]_{ij}$の微分の連鎖律

\frac{\partial f}{\partial [U_k]_{ij}} = \sum_{l,m,n} \frac{\partial f}{\partial [C_n]_{lm}} \frac{\partial [C_n]_{lm}}{\partial [U_k]_{ij}} 

が書けることに注意します。pullbackを

\left[ {\cal B}_{C_{n}}(\bar{C}_{n}) \right]_{ij} \equiv \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} \frac{\partial [C_n]_{lm}}{\partial [U_k]_{ij}} 

と定義すると、

\frac{\partial f}{\partial [U_k]_{ij}} = \sum_n \left[ {\cal B}_{C_{n}}(\bar{C}_{n}) \right]_{ij}

となります。

シフトした量の微分

$p$だけシフトしたもの$C_n({ U }) = U_{n+p}$の微分を考えます。これは、

\frac{\partial [C_n]_{lm}}{\partial [U_k]_{ij}} = \frac{\partial [U_{n+p}]_{lm}}{\partial [U_k]_{ij}} = \delta_{n,k-p} \delta_{li} \delta_{mj}

となり、pullbackは

\left[ {\cal B}_{C_{n}}(\bar{C}_{n}) \right]_{ij} = \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} \delta_{n,k-p} \delta_{li} \delta_{mj} =\delta_{n,k-p} \frac{\partial f}{\partial [C_{n}]_{ij}}

となります。ですので、

function ChainRulesCore.rrule(::typeof(shift_Field),a::Field{Nc,L},shift) where {Nc,L} 
    y = shift_Field(a,shift)
    function pullback(ybar)
        sbar = NoTangent()
        fabar = shift_Field(ybar,-shift)
        return sbar,fabar
    end
    return y, pullback
end

と実装してしまいましょう。

掛け算

$C_n({ U },{ P }) = U_n P_{n+p}$の場合、

\frac{\partial [C_n]_{lm}}{\partial [U_k]_{ij}} = \frac{\partial \sum_{a} [U_n]_{la} [P_{n+p}]_{am}}{\partial [U_k]_{ij}} = [P_{n+p}]_{jm} \delta_{n,k} \delta_{li}
\frac{\partial [C_n]_{lm}}{\partial [P_k]_{ij}} = [U_{n}]_{li} \delta_{n+p,k}  \delta_{mj}

となりますから、
pullbackは$U$に関する微分と$P$に関する微分の二種類が現れ、

\left[ {\cal B}_{C_{n}}(\bar{C}_{n}) \right]_{ij} = 
\left( \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} [P_{n+p}]_{jm} \delta_{n,k} \delta_{li} , \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} [U_{n}]_{li} \delta_{n+p,k} \delta_{mj} \right) \\
= \left( \sum_{m} \frac{\partial f}{\partial [C_n]_{im}} [P_{n+p}]_{jm} \delta_{n,k} ,\sum_{l} \frac{\partial f}{\partial [C_n]_{lj}} [U_{n}]_{li} \delta_{n+p,k}   \right) \\
= \left( \delta_{n,k}  \left[ \frac{\partial f}{\partial [C_n]} P_{n+p}^T \right]_{ij} ,\delta_{n,k-p}  \left[ U_{n}^T \frac{\partial f}{\partial [C_n]}\right]_{ij}  \right) 

となります。

$C_n({ U },{ P }) = P_{n+p} U_n$の場合も同様に、

\frac{\partial [C_n]_{lm}}{\partial [U_k]_{ij}} = \frac{\partial \sum_{a} [P_{n+p}]_{la} [U_{n}]_{am}}{\partial [U_k]_{ij}} = [P_{n+p}]_{li} \delta_{n,k} \delta_{mj}  
\frac{\partial [C_n]_{lm}}{\partial [P_k]_{ij}} = [U_{n}]_{jm} \delta_{n+p,k} \delta_{il} 

となりますから、pullbackは

\left[ {\cal B}_{C_{n}}(\bar{C}_{n}) \right]_{ij} = 
\left( \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} [P_{n+p}]_{li} \delta_{n,k}\delta_{mj}  , \sum_{l,m} \frac{\partial f}{\partial [C_n]_{lm}} [U_{n}]_{jm} \delta_{n+p,k} \delta_{il} \right) \\
= \left( \sum_{l} \frac{\partial f}{\partial [C_n]_{lj}} [P_{n+p}]_{li} \delta_{n,k}\  ,\sum_{m} \frac{\partial f}{\partial [C_n]_{im}} [U_{n}]_{jm} \delta_{n+p,k} \right) \\
= \left( \delta_{n,k}  \left[ P_{n+p}^T \frac{\partial f}{\partial [C_n]}  \right]_{ij} ,\delta_{n,k-p}  \left[ \frac{\partial f}{\partial [C_n]} U_{n}^T \right]_{ij}  \right) 

となります。

よって、それぞれの関数は

function ChainRulesCore.rrule(::typeof(*),a::Field{Nc,L},b::Shifted_Field{Nc,L,shift}) where {Nc,L,shift} 
    y = a * b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = ybar*b'
        f2 = a'*ybar
        fbbar = shift_Field(f2,-shift)
        return sbar,fabar,fbbar
    end
    return y, pullback
end

function ChainRulesCore.rrule(::typeof(*),a::Shifted_Field{Nc,L,shift},b::Field{Nc,L}) where {Nc,L,shift} 
    y = a * b
    function pullback(ybar)
        sbar = NoTangent()
        fabar = a'*ybar
        f2 = ybar*b'
        fbbar = shift_Field(f2,-shift)
        return sbar,fabar,fbbar
    end
    return y, pullback
end


となります。

自動微分の実行

あとは自動微分を試します。数式を関数にすると

function calc_tr(x)
    xshift_1 = shift_Field(x,1)
    xshift_2 = shift_Field(x,2)
    xshift_3 = shift_Field(x,3)
    return tr(x*xshift_1*xshift_2*xshift_3)
end

となります。これを自動微分してみて、数値微分と比較してみます。

function test()
    Nc = 2
    L = 4
    a = random_field(Nc,L)
    b = random_field(Nc,L)
    c = a*b
    display(c)
    d = a+b
    display(d)
    println(tr(c))

    a_p = copy(a)
    delta = 1e-9
    a_p.A[1,1,1] += delta 

    fa = calc_tr(a)
    fad = calc_tr(a_p)
    df = (fad - fa)/delta
    println(df)

    ff = gradient(x -> calc_tr(x),a)
    println(ff[1].A[1,1,1])
    return
end
test()

ちゃんと微分できていることがわかると思います。 

4
1
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1