0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

juliaで自動微分の感覚を掴む

Last updated at Posted at 2022-01-25

juliaで自動微分(バックプロパゲーション)の感覚を掴みたい

こちらの本 ゼロから作るDeep Learning ❸ ―フレームワーク編の最初の方の実装部分を、juliaで実装してみました。

自動微分(バックプロパゲーション)

詳しい解説などは上記の本などを参考にしてください。

簡単にいうと、関数を計算するときに、その関数の微分も同時に計算することで、それらの情報をもとに合成関数の微分の計算を高速に行います。

sin, cosなど有名なものから構成される関数は、自動微分のライブラリーで既に実装されているようです。
ここでは裏でどのような実装になっているのかの感覚を掴むために、それらを手で実装してみます。

juliaで自動微分の実装部分は、Zygoteに関する記事を参考にさせて頂きました。

実装

step1からstep4からなります。
自動微分の実装だけみたい人は、step4を見てください。

step1

変数の定義をします。

#変数の型を決めます。
struct Variable
    n::Number
end

function data(m::Variable)
    return m.n
end

m = Variable(1.0)
ans = data(m)
@show ans #1

step2

関数を定義します。

struct Variable
    data::Number
end

#抽象型を定義します。
abstract type Function end 

#関数を定義します。
function (f::Function)(v::Variable)
    x = v.data
    y = forward(f, x) #juliaの内部で以下のforwardを実行する。
    return Variable(y)
end

#以下、具体的な関数x^2, x^3, exp(x)を定義している。
struct Square <:Function #具体化することで抽象型Functionの範囲を狭めるイメージ
end

struct Cube <:Function #継承っぽいもの。サブタイプしている
end

struct Exp <:Function
end

#forward(f::Function, v::Variable) = forward(f, v.data)
forward(::Square, x::Number) = x^2 #メソッドを定義してる。 ::Square==f::Square. 第一引数は省略した。 
forward(::Cube, x::Number) = x^3 
forward(::Exp, x::Number) = exp(x)

s = Square()
v = Variable(3)
@show forward(s, v.data) # 9
@show s(v) # Variable(9)

c = Cube()
@show forward(c, v.data) # 27
@show c(v) # Variable(27)

e = Exp()
@show forward(s, v.data)
@show c(v)

step3

一応、数値微分を実装します。step2で実装した関数を使って合成関数z = $(exp(x^2))^2$の数値微分を計算します。

struct Variable
    data::Number
end

abstract type Function end 

function (f::Function)(v::Variable)
    x = v.data
    y = forward(f, x)
    return Variable(y)
end

struct Square <:Function #継承っぽいもの。サブタイプしている
end


function numerical_dif(f, x , eps)
    x1 = Variable(x.data + eps)
    x2 = Variable(x.data - eps)
    y1 = forward(f, x1.data)
    y2 = forward(f, x2.data)
    return (y1 - y2) / 2eps
end

forward(::Square, x::Number) = (exp(x^2))^2  #forwardのメソッドを定義してる。 ::Square==f::Square. 第一引数は省略した。 

s = Square()
v = Variable(0.5)
eps=0.001
s_ = numerical_dif(s,v,eps)
@show s_ #3.2974513345919165

step4

自動微分を実装します。

入力xとして合成関数zを以下のように計算していきます。
$a = f(x) = x^2$

$y = g(a) = \exp(a)$

$z = h(y) = y^2$

ここで、 $z = h(g(f(x))) = (\exp(x^2))^2$ を自動微分するのが本記事のゴールです。

自動微分のポイントは、関数の微分(dz/dy, dy/da, da/dx)の式の形を覚えておくことです。

ここで、プログラム中のpullback関数は、Zygoteに関する記事のgradient and pullbackの節を参照にしました。

本実装では、合成関数zに関するpullbackの実装において以下の3つの値を計算することで、合成関数zのxに対する微分を計算してます。

\begin{aligned}
\bar{x} &=\frac{d z}{d x}=\frac{d z}{d a} \frac{d a}{d x}=\bar{a} \frac{d a}{d x}=B_{x}(\bar{a}) \\
\bar{a} &=\frac{d z}{d a}=\frac{d z}{d y} \frac{d y}{d a}=B_{a}(\bar{y})\\
\bar{y} &=\frac{d z}{d y}=\frac{d z}{d z} \frac{d z}{d y}=B_{y}(\bar{z}) 
\end{aligned} 

ここで、上式の$B$ が実装中でbackwardと呼んでいる関数です。

まず、$\bar{z}(=1)$ を最初のbackward関数 $B_{y} $の引数に代入することで、$\bar{y}$ を計算します。 

同じように、$\bar{a}$, $\bar{x}$ を計算します。

struct Variable 
    data::Number
end

Base.:*(v1::Variable, v2::Variable) = Variable(v1.data * v2.data)
Base.:*(c::Number, v1::Variable) = Variable(v1.data * c) #redundancy? 
Base.:*(v1::Variable, c::Number) = Variable(v1.data * c) #redundancy?

abstract type Function end

function (f::Function)(v::Variable)
    x = v.data
    y = forward(f, x)
    output = Variable(y)
    return output
end

struct Exp<:Function
end
struct Square<:Function
end

struct z <:Function
end


forward(::Square, x::Number) = x^2
forward(::Exp, x::Number) = exp(x) 
forward(::z, x::Number) = (exp(x^2))^2 #これを自動微分するのがゴール

#関数の値とそのbackwardを返す関数
function pullback(f::Exp, input::Variable) # f = exp(x)
    y = f(input)
    function backward(gy::Variable)
        return f(input)  * gy 
    end
    return (y, backward)
end

function pullback(f::Square, input::Variable) #f = x^2
    y = f(input)
    function backward(gy::Variable)
        return  2 * input * gy                                 
    end
    return (y, backward)
end


#合成関数zとそのbackwardの実装
function pullback(z::Function, input::Variable)
    f = Square()
    g = Exp()
    h = Square()
    a, back_for_f = pullback(f::Square, input)
    b, back_for_g = pullback(g::Exp, a)
    z, back_for_h = pullback(h::Square, b)
   
    function back_for_z(z_) 
        y_bar = back_for_h(z_)
        a_bar = back_for_g(y_bar)
        x_bar = back_for_f(a_bar)
        return x_bar
    end
    return (z, back_for_z)
end

function mygradient(y, input::Variable) #例えばyがSquare()なら2*xがかえる
    _, back = pullback(y, input)
    return back(Variable(1.0)) 
end


x = Variable(0.5)
D = z() #z = (f ∘ g ∘ h)()
res2 = mygradient(D, x)
println(res2.data) #step3で実装した数値微分の値と同じになる

まとめ

合成関数に対する自動微分を実装し理解が深まりました。
なお、今回は合成関数を定義しましたが、juliaの複合関数を使えばもっと簡単に実装できる方法があるかもしれません。

参考

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?