更新履歴:
2019/01/09 記事作成開始。
2019/01/24 Appendix部分完了。
ここでは2018年に注目を浴びた論文である、Neural Ordinary Differential Equations (Chen et al.)の解説を行います。主に式変形などの補完ができればと思っています。ニューラルネットワークや統計に関して勉強を始めて間もないため、間違い等ありましたら、コメントにてどしどしご指摘ください。
1. Introduction
$\Delta t$をレイヤー幅として、ニューラルネットワークのアウトプット結果${\bf z}_{t}$の変化を表す方程式は
\frac{{\bf z}_{t+\Delta t} - {\bf z}_{t}}{\Delta t}
=f({\bf z}(t), t, \theta)
のように書くことができます(論文(1)式)。今までは有限のレイヤー幅で考えていたため、このように離散化された形で書かれていました。ここでこのレイヤー幅が無限に小さいとしたら、この方程式の左辺は、微分の定義より
\frac{d{\bf z}(t)}{dt}
= f({\bf z}(t), t, \theta)
のように連続した値を持つ常微分方程式(ordinary differential equation: ODE)の形になります(論文(2)式)。ODEの形であれば、インプットレイヤー(初期値)の値がわかれば、アウトプットの値は適当なソルバーを使って求められるだろう、というのがこの論文の趣旨です。ニューラルネットワークの深さ1, 2, ...などに制約を受けることなく、例えば深さ2.5, 3.1といった部分でのアウトプットも得られるというわけです(論文図(1))。
ODEソルバーを用いてニューラルネットワークモデルを構築すると、以下のような利点があります。
- Memory efficiency(逆伝播を解く必要がないため、メモリの節約になる)
- Adaptive computation(いろいろな計算方法があるため、用途に合わせた計算が可能である)
- Parameter efficiency(パラメータが少なくて済む)
- Scalable and invertible normalizing flows(計算が簡単になる(?))
- Continuous time-series models(連続してアウトプットが出せる)
2. Reverse-mode automatic differentiation of ODE solutions
ここではadjoint method(Pontryagin et al., 1962)の方法を用いてODEを解く方法を示しています(詳細な式変形は近々掲載予定、ちなみに論文ではAppendix. Bに示されています)。
scalar-valued loss function $L()$を、論文(2)式をODEソルバーで解いた結果を用いて
L({\bf z}(t_1))
=L\left( \int_{t_0}^{t_1} f({\bf z}(t), t, \theta) dt \right)
=L({\rm ODE Solve}({\bf z}(t_0), f, t_0, t_1, \theta))
と書きます(論文(3)式)。$L$の挙動が知りたければ、パラメータ${\bf z}(t_0), t_0, t_1, \theta$に対する傾きがわかればよいのです。
まず最初に${\bf z}(t)$におけるlossの傾きを求めます。この量をadjointと呼び、
a(t)=-\frac{\partial L}{\partial {\bf z}(t)}
この$a(t)$はもう一本の微分方程式
\frac{da(t)}{dt} = -a(t)^{\mathrm{T}} \frac{\partial f({\bf z}, t, \theta)}{\partial {\bf z}}
に従うことが連鎖律から導くことができます(論文(4)式)。この導出をAppendixに示しました。
Appendix, 論文(4)式の導出
{\bf a}(t)
=- \frac{\partial L}{\partial {\bf z}(t)}
=- \frac{\partial L}{\partial {\bf z}(t+\Delta t)}\frac{\partial {\bf z}(t+\Delta t)}{\partial {\bf z}(t)}
={\bf a}(t+\Delta t) \frac{\partial T_{\Delta t}}{\partial {\bf z}(t)}
途中
T_{\Delta t}={\bf z} (t+\Delta t)
としました。ここで
\frac{d{\bf z}}{dt} = f({\bf z}(t), t, \theta)
より
T_{\Delta t}
= {\bf z} (t+\Delta t)
= {\bf z} (t) + \int_t^{t+\Delta t} f dt
\simeq {\bf z} (t) +f \Delta t
よって
\begin{align}
\frac{d{\bf a}}{dt}
&=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t)}{\Delta t}
=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t+ \Delta t) \frac{\partial T_{\Delta t}}{\partial {\bf z}}}{\Delta t} \\
&=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t+\Delta t)\frac{\partial}{\partial {\bf z}} ({\bf z} + f \Delta t)}{\Delta t}
=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t+ \Delta t) \left( {\bf I} + \frac{\partial f}{\partial {\bf z}} \Delta t \right)}{\Delta t} \\
&= -\lim_{\Delta t \to 0} {\bf a}(t+\Delta t) \frac{\partial f}{\partial {\bf z}}
=-{\bf a}(t) \frac{\partial f}{\partial {\bf z}}
\end{align}
となります。