Edited at

論文紹介: Neural Ordinary Differential Equations

更新履歴:

2019/01/09 記事作成開始。

2019/01/24 Appendix部分完了。

ここでは2018年に注目を浴びた論文である、Neural Ordinary Differential Equations (Chen et al.)の解説を行います。主に式変形などの補完ができればと思っています。ニューラルネットワークや統計に関して勉強を始めて間もないため、間違い等ありましたら、コメントにてどしどしご指摘ください。


1. Introduction

$\Delta t$をレイヤー幅として、ニューラルネットワークのアウトプット結果${\bf z}_{t}$の変化を表す方程式は

\frac{{\bf z}_{t+\Delta t} - {\bf z}_{t}}{\Delta t}

=f({\bf z}(t), t, \theta)

のように書くことができます(論文(1)式)。今までは有限のレイヤー幅で考えていたため、このように離散化された形で書かれていました。ここでこのレイヤー幅が無限に小さいとしたら、この方程式の左辺は、微分の定義より

\frac{d{\bf z}(t)}{dt} 

= f({\bf z}(t), t, \theta)

のように連続した値を持つ常微分方程式(ordinary differential equation: ODE)の形になります(論文(2)式)。ODEの形であれば、インプットレイヤー(初期値)の値がわかれば、アウトプットの値は適当なソルバーを使って求められるだろう、というのがこの論文の趣旨です。ニューラルネットワークの深さ1, 2, ...などに制約を受けることなく、例えば深さ2.5, 3.1といった部分でのアウトプットも得られるというわけです(論文図(1))。

ODEソルバーを用いてニューラルネットワークモデルを構築すると、以下のような利点があります。


  • Memory efficiency(逆伝播を解く必要がないため、メモリの節約になる)

  • Adaptive computation(いろいろな計算方法があるため、用途に合わせた計算が可能である)

  • Parameter efficiency(パラメータが少なくて済む)

  • Scalable and invertible normalizing flows(計算が簡単になる(?))

  • Continuous time-series models(連続してアウトプットが出せる)


2. Reverse-mode automatic differentiation of ODE solutions

ここではadjoint method(Pontryagin et al., 1962)の方法を用いてODEを解く方法を示しています(詳細な式変形は近々掲載予定、ちなみに論文ではAppendix. Bに示されています)。

scalar-valued loss function $L()$を、論文(2)式をODEソルバーで解いた結果を用いて

L({\bf z}(t_1))

=L\left( \int_{t_0}^{t_1} f({\bf z}(t), t, \theta) dt \right)
=L({\rm ODE Solve}({\bf z}(t_0), f, t_0, t_1, \theta))

と書きます(論文(3)式)。$L$の挙動が知りたければ、パラメータ${\bf z}(t_0), t_0, t_1, \theta$に対する傾きがわかればよいのです。

まず最初に${\bf z}(t)$におけるlossの傾きを求めます。この量をadjointと呼び、

a(t)=-\frac{\partial L}{\partial {\bf z}(t)}

この$a(t)$はもう一本の微分方程式

\frac{da(t)}{dt} = -a(t)^{\mathrm{T}} \frac{\partial f({\bf z}, t, \theta)}{\partial {\bf z}}

に従うことが連鎖律から導くことができます(論文(4)式)。この導出をAppendixに示しました。


Appendix, 論文(4)式の導出

{\bf a}(t)

=- \frac{\partial L}{\partial {\bf z}(t)}
=- \frac{\partial L}{\partial {\bf z}(t+\Delta t)}\frac{\partial {\bf z}(t+\Delta t)}{\partial {\bf z}(t)}
={\bf a}(t+\Delta t) \frac{\partial T_{\Delta t}}{\partial {\bf z}(t)}

途中

T_{\Delta t}={\bf z} (t+\Delta t) 

としました。ここで

\frac{d{\bf z}}{dt} = f({\bf z}(t), t, \theta)

より

T_{\Delta t} 

= {\bf z} (t+\Delta t)
= {\bf z} (t) + \int_t^{t+\Delta t} f dt
\simeq {\bf z} (t) +f \Delta t

よって

\begin{align}

\frac{d{\bf a}}{dt}
&=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t)}{\Delta t}
=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t+ \Delta t) \frac{\partial T_{\Delta t}}{\partial {\bf z}}}{\Delta t} \\
&=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t+\Delta t)\frac{\partial}{\partial {\bf z}} ({\bf z} + f \Delta t)}{\Delta t}
=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t+ \Delta t) \left( {\bf I} + \frac{\partial f}{\partial {\bf z}} \Delta t \right)}{\Delta t} \\
&= -\lim_{\Delta t \to 0} {\bf a}(t+\Delta t) \frac{\partial f}{\partial {\bf z}}
=-{\bf a}(t) \frac{\partial f}{\partial {\bf z}}
\end{align}

となります。


参考記事

AINOU記事

SlideShare