Juliaの機械学習パッケージFluxでは、自動微分が使われています。この自動微分はZygote.jlというパッケージが使われています。Zygote.jlを使うと、関数を簡単に自動微分できます。この自動微分はどのような仕組みで実現されているのでしょうか?
Juliaの自動微分パッケージではDon't Unroll Adjoint: Differentiating SSA-Form Programsという論文にある技術が採用されています。
この記事の後半では、ここで説明した方法を用いて自分で自動微分を実装することにします。この前半の記事では、自動微分に使われている手法についての解説をしています。
環境
- Julia 1.7.3
Zygote.jlによる自動微分
まず、自動微分をやってみましょう。
関数として、
g(x) = \exp (3 x^2 + x)
を考えてみます。この微分は
g'(x) = (6 x + 1) \exp (3 x^2 + x)
ですね。これを自動微分で計算してみます。
using Zygote
function test()
g(x) = exp(3*x^2 + x)
gt(x) = (2*3*x + 1)*exp(3*x^2+x)
x = 0.1
println("$(g'(x)) $(gt(x))")
end
test()
とします。手で計算した微分と自動微分による結果を比較してみました。このようにg'(x)
で簡単に自動微分ができてしまいます。
Zygoteではもっと複雑な関数も自動微分できます。
例えば、
function sample_func(x)
a = sin(x)
b = 0.2 + a
c = sqrt(1+b)
return c
end
という関数も
x = 3
println(sample_func'(x))
で微分できます。
さらに、
function sample_func2(x)
a = [1+x 2*x^2
3 4+x]
e,v = eigen(a)
b = cos(e[1])
return b
end
のように途中に固有値計算が挟まっている関数でも、
dx = 1e-6
numerical_value = (sample_func2(x+dx) - sample_func2(x))/dx
println(sample_func2'(x),"\t",numerical_value )
微分できてしまいます。ここで、numerical_value
は数値微分です。
自動微分の仕組み
Zygote.jlの自動微分にはChainRulesCore.jlが使われています。このChainRulesCoreはどのように自動微分を実現しているのでしょうか? sample_func(x)
を例にして調べてみましょう。sample_func(x)
を数式で表すと
\begin{align}
a(x) = \sin(x) \\
b(a) = 0.2 + a \\
c(b) = \sqrt{1+b} \\
f(x) = c(b(a(x)))
\end{align}
となっています。計算したい量は$\partial f/\partial x$です。$f=c$ですから、$\partial c/\partial x$を計算します。$c$は間接的に$a$の関数ですから、連鎖律を使って、
\begin{align}
\frac{\partial c}{\partial x} = \frac{\partial a}{\partial x} \frac{\partial c}{\partial a}
\end{align}
となります。ここで、$\frac{\partial a}{\partial x} = \cos x$と簡単に計算できます。
$c$は$b$の関数でもあるので、
\begin{align}
\frac{\partial c}{\partial a} = \frac{\partial b}{\partial a} \frac{\partial c}{\partial b}
\end{align}
となりますが、$\frac{\partial b}{\partial a} = 1$と簡単に計算できます。さらに、
\begin{align}
\frac{\partial c}{\partial b} = \frac{\partial c}{\partial b} \frac{\partial c}{\partial c}
\end{align}
とすると、$\frac{\partial c}{\partial b} = (1/2)/\sqrt{1+b}$、$\frac{\partial c}{\partial c} =1$となりますから、計算できます。
つまり、
\begin{align}
\frac{\partial c}{\partial c} =1 \\
\frac{\partial c}{\partial b} = \frac{\partial c}{\partial b} \frac{\partial c}{\partial c} \\
\frac{\partial c}{\partial a} = \frac{\partial b}{\partial a} \frac{\partial c}{\partial b} \\
\frac{\partial c}{\partial x} = \frac{\partial a}{\partial x} \frac{\partial c}{\partial a} \\
\frac{\partial f}{\partial x} = \frac{\partial c}{\partial x}
\end{align}
という流れで計算が可能です。
元の関数の計算では
x \rightarrow a \rightarrow b \rightarrow c \rightarrow f
という流れで計算していましたが、微分は
\frac{\partial c}{\partial c} \rightarrow \frac{\partial c}{\partial b} \rightarrow \frac{\partial c}{\partial a} \rightarrow \frac{\partial c}{\partial x} \rightarrow \frac{\partial f}{\partial x}
という流れで計算できます。元の計算と変数の変化の流れが逆になっていますね。これをReverse modeと呼びます。
pullbackの仕組み
さて、微分の計算の流れはわかりました。これを実装する方法について考えてみましょう。$a$や$b$や$c$では一般化が難しそうですので、$f(x)$の計算を
\begin{align}
a_1 = f_1(x) \\
a_2 = f_2(a_1) \\
\vdots \\
a_i = f_{i}(a_{i-1}) \\
\vdots \\
a_{N-1} = f_{N-1}(a_{N-2})\\
a_N = f_N(a_{N-1})\\
f(x) = a_N
\end{align}
と書いておきましょう。このように書くと、微分の計算は
\begin{align}
\frac{\partial a_N}{\partial a_N} = 1 \\
\frac{\partial a_N}{\partial a_{N-1}} = \frac{\partial a_{N}}{\partial a_{N-1}} \frac{\partial a_N}{\partial a_N} = \frac{\partial f_N(x)}{\partial x} \Big{|}_{a_{N-1}} \frac{\partial a_N}{\partial a_N} \\
\frac{\partial a_{N}}{\partial a_{N-2}} = \frac{\partial a_{N-1}}{\partial a_{N-2}} \frac{\partial a_{N}}{\partial a_{N-1}} = \frac{\partial f_{N-1}(x)}{\partial x} \Big{|}_{a_{N-2}} \frac{\partial a_{N}}{\partial a_{N-1}} \\
\vdots \\
\frac{\partial a_{N}}{\partial a_{i-1}} = \frac{\partial f_{i}(x)}{\partial x} \Big{|}_{a_{i-1}} \frac{\partial a_{N}}{\partial a_{i}} \\
\vdots \\
\frac{\partial a_{N}}{\partial a_{1}} = \frac{\partial f_{2}(x)}{\partial x} \Big{|}_{a_{1}} \frac{\partial a_{N}}{\partial a_{2}} \\
\frac{\partial a_{N}}{\partial x} = \frac{\partial f_{1}(x)}{\partial x} \Big{|}_{x} \frac{\partial a_{N}}{\partial a_{1}}
\end{align}
となります。ポイントは、$\frac{\partial a_{N}}{\partial a_{i-1}} $の計算に、$\frac{\partial f_{i}(x)}{\partial x} $と$\frac{\partial a_{N}}{\partial a_{i}} $が必要ということです。前者は手で微分が計算できる関数、後者は伝播してやってきた係数です。後者はいつも$a_{N}$を微分する形になっていますから、$a_{N}$を$y$で偏微分したものを
\bar{y} \equiv \frac{\partial a_{N}}{\partial y}
と定義します。この時、
\begin{align}
\frac{\partial a_{N}}{\partial a_{i-1}} = \frac{\partial f_{i}(x)}{\partial x} \Big{|}_{a_{i-1}} \bar{a_{i}}
\end{align}
と書けます。
ここで、pullbackという概念を導入します。関数$y = f(x)$に対して、そのpullback関数$B_{y}(\bar{y})$を
\begin{align}
\bar{x} = \frac{\partial a_{N}}{\partial x} = \frac{\partial y}{\partial x} \frac{\partial a_{N}}{\partial y} = \frac{\partial y}{\partial x} \bar{y} \equiv B_{y}(\bar{y})
\end{align}
と定義します。例えば、$y = \sin(x)$であれば、
B_{y}(\bar{y}) = \cos(x) \bar{y}
となります。
計算の途中のpullback関数が全て与えられていれば、順番に計算していくことで微分が計算することができるわけです。それぞれの計算では微分が手で与えられているために、複雑な関数の微分を正確に計算できていることになります。
後半へ続く
後半では、このpullback関数を用いることで、実際に自動微分ができることを示します。