juliaで自動微分(バックプロパゲーション)の感覚を掴みたい
こちらの本 ゼロから作るDeep Learning ❸ ―フレームワーク編の最初の方の実装部分を、juliaで実装してみました。
自動微分(バックプロパゲーション)
詳しい解説などは上記の本などを参考にしてください。
簡単にいうと、関数を計算するときに、その関数の微分も同時に計算することで、それらの情報をもとに合成関数の微分の計算を高速に行います。
sin, cosなど有名なものから構成される関数は、自動微分のライブラリーで既に実装されているようです。
ここでは裏でどのような実装になっているのかの感覚を掴むために、それらを手で実装してみます。
juliaで自動微分の実装部分は、Zygoteに関する記事を参考にさせて頂きました。
実装
step1からstep4からなります。
自動微分の実装だけみたい人は、step4を見てください。
step1
変数の定義をします。
#変数の型を決めます。
struct Variable
n::Number
end
function data(m::Variable)
return m.n
end
m = Variable(1.0)
ans = data(m)
@show ans #1
step2
関数を定義します。
struct Variable
data::Number
end
#抽象型を定義します。
abstract type Function end
#関数を定義します。
function (f::Function)(v::Variable)
x = v.data
y = forward(f, x) #juliaの内部で以下のforwardを実行する。
return Variable(y)
end
#以下、具体的な関数x^2, x^3, exp(x)を定義している。
struct Square <:Function #具体化することで抽象型Functionの範囲を狭めるイメージ
end
struct Cube <:Function #継承っぽいもの。サブタイプしている
end
struct Exp <:Function
end
#forward(f::Function, v::Variable) = forward(f, v.data)
forward(::Square, x::Number) = x^2 #メソッドを定義してる。 ::Square==f::Square. 第一引数は省略した。
forward(::Cube, x::Number) = x^3
forward(::Exp, x::Number) = exp(x)
s = Square()
v = Variable(3)
@show forward(s, v.data) # 9
@show s(v) # Variable(9)
c = Cube()
@show forward(c, v.data) # 27
@show c(v) # Variable(27)
e = Exp()
@show forward(s, v.data)
@show c(v)
step3
一応、数値微分を実装します。step2で実装した関数を使って合成関数z = $(exp(x^2))^2$の数値微分を計算します。
struct Variable
data::Number
end
abstract type Function end
function (f::Function)(v::Variable)
x = v.data
y = forward(f, x)
return Variable(y)
end
struct Square <:Function #継承っぽいもの。サブタイプしている
end
function numerical_dif(f, x , eps)
x1 = Variable(x.data + eps)
x2 = Variable(x.data - eps)
y1 = forward(f, x1.data)
y2 = forward(f, x2.data)
return (y1 - y2) / 2eps
end
forward(::Square, x::Number) = (exp(x^2))^2 #forwardのメソッドを定義してる。 ::Square==f::Square. 第一引数は省略した。
s = Square()
v = Variable(0.5)
eps=0.001
s_ = numerical_dif(s,v,eps)
@show s_ #3.2974513345919165
step4
自動微分を実装します。
入力xとして合成関数zを以下のように計算していきます。
$a = f(x) = x^2$
$y = g(a) = \exp(a)$
$z = h(y) = y^2$
ここで、 $z = h(g(f(x))) = (\exp(x^2))^2$ を自動微分するのが本記事のゴールです。
自動微分のポイントは、関数の微分(dz/dy, dy/da, da/dx)の式の形を覚えておくことです。
ここで、プログラム中のpullback関数は、Zygoteに関する記事のgradient and pullbackの節を参照にしました。
本実装では、合成関数zに関するpullbackの実装において以下の3つの値を計算することで、合成関数zのxに対する微分を計算してます。
\begin{aligned}
\bar{x} &=\frac{d z}{d x}=\frac{d z}{d a} \frac{d a}{d x}=\bar{a} \frac{d a}{d x}=B_{x}(\bar{a}) \\
\bar{a} &=\frac{d z}{d a}=\frac{d z}{d y} \frac{d y}{d a}=B_{a}(\bar{y})\\
\bar{y} &=\frac{d z}{d y}=\frac{d z}{d z} \frac{d z}{d y}=B_{y}(\bar{z})
\end{aligned}
ここで、上式の$B$ が実装中でbackwardと呼んでいる関数です。
まず、$\bar{z}(=1)$ を最初のbackward関数 $B_{y} $の引数に代入することで、$\bar{y}$ を計算します。
同じように、$\bar{a}$, $\bar{x}$ を計算します。
struct Variable
data::Number
end
Base.:*(v1::Variable, v2::Variable) = Variable(v1.data * v2.data)
Base.:*(c::Number, v1::Variable) = Variable(v1.data * c) #redundancy?
Base.:*(v1::Variable, c::Number) = Variable(v1.data * c) #redundancy?
abstract type Function end
function (f::Function)(v::Variable)
x = v.data
y = forward(f, x)
output = Variable(y)
return output
end
struct Exp<:Function
end
struct Square<:Function
end
struct z <:Function
end
forward(::Square, x::Number) = x^2
forward(::Exp, x::Number) = exp(x)
forward(::z, x::Number) = (exp(x^2))^2 #これを自動微分するのがゴール
#関数の値とそのbackwardを返す関数
function pullback(f::Exp, input::Variable) # f = exp(x)
y = f(input)
function backward(gy::Variable)
return f(input) * gy
end
return (y, backward)
end
function pullback(f::Square, input::Variable) #f = x^2
y = f(input)
function backward(gy::Variable)
return 2 * input * gy
end
return (y, backward)
end
#合成関数zとそのbackwardの実装
function pullback(z::Function, input::Variable)
f = Square()
g = Exp()
h = Square()
a, back_for_f = pullback(f::Square, input)
b, back_for_g = pullback(g::Exp, a)
z, back_for_h = pullback(h::Square, b)
function back_for_z(z_)
y_bar = back_for_h(z_)
a_bar = back_for_g(y_bar)
x_bar = back_for_f(a_bar)
return x_bar
end
return (z, back_for_z)
end
function mygradient(y, input::Variable) #例えばyがSquare()なら2*xがかえる
_, back = pullback(y, input)
return back(Variable(1.0))
end
x = Variable(0.5)
D = z() #z = (f ∘ g ∘ h)()
res2 = mygradient(D, x)
println(res2.data) #step3で実装した数値微分の値と同じになる
まとめ
合成関数に対する自動微分を実装し理解が深まりました。
なお、今回は合成関数を定義しましたが、juliaの複合関数を使えばもっと簡単に実装できる方法があるかもしれません。