LoginSignup
25
19

More than 5 years have passed since last update.

更新履歴:
2019/01/09 記事作成開始。
2019/01/24 Appendix部分完了。

ここでは2018年に注目を浴びた論文である、Neural Ordinary Differential Equations (Chen et al.)の解説を行います。主に式変形などの補完ができればと思っています。ニューラルネットワークや統計に関して勉強を始めて間もないため、間違い等ありましたら、コメントにてどしどしご指摘ください。

1. Introduction

$\Delta t$をレイヤー幅として、ニューラルネットワークのアウトプット結果${\bf z}_{t}$の変化を表す方程式は

\frac{{\bf z}_{t+\Delta t} - {\bf z}_{t}}{\Delta t}
=f({\bf z}(t), t, \theta)

のように書くことができます(論文(1)式)。今までは有限のレイヤー幅で考えていたため、このように離散化された形で書かれていました。ここでこのレイヤー幅が無限に小さいとしたら、この方程式の左辺は、微分の定義より

\frac{d{\bf z}(t)}{dt} 
= f({\bf z}(t), t, \theta)

のように連続した値を持つ常微分方程式(ordinary differential equation: ODE)の形になります(論文(2)式)。ODEの形であれば、インプットレイヤー(初期値)の値がわかれば、アウトプットの値は適当なソルバーを使って求められるだろう、というのがこの論文の趣旨です。ニューラルネットワークの深さ1, 2, ...などに制約を受けることなく、例えば深さ2.5, 3.1といった部分でのアウトプットも得られるというわけです(論文図(1))。

ODEソルバーを用いてニューラルネットワークモデルを構築すると、以下のような利点があります。

  • Memory efficiency(逆伝播を解く必要がないため、メモリの節約になる)
  • Adaptive computation(いろいろな計算方法があるため、用途に合わせた計算が可能である)
  • Parameter efficiency(パラメータが少なくて済む)
  • Scalable and invertible normalizing flows(計算が簡単になる(?))
  • Continuous time-series models(連続してアウトプットが出せる)

2. Reverse-mode automatic differentiation of ODE solutions

ここではadjoint method(Pontryagin et al., 1962)の方法を用いてODEを解く方法を示しています(詳細な式変形は近々掲載予定、ちなみに論文ではAppendix. Bに示されています)。

scalar-valued loss function $L()$を、論文(2)式をODEソルバーで解いた結果を用いて

L({\bf z}(t_1))
=L\left( \int_{t_0}^{t_1} f({\bf z}(t), t, \theta) dt \right)
=L({\rm ODE Solve}({\bf z}(t_0), f, t_0, t_1, \theta))

と書きます(論文(3)式)。$L$の挙動が知りたければ、パラメータ${\bf z}(t_0), t_0, t_1, \theta$に対する傾きがわかればよいのです。

まず最初に${\bf z}(t)$におけるlossの傾きを求めます。この量をadjointと呼び、

a(t)=-\frac{\partial L}{\partial {\bf z}(t)}

この$a(t)$はもう一本の微分方程式

\frac{da(t)}{dt} = -a(t)^{\mathrm{T}} \frac{\partial f({\bf z}, t, \theta)}{\partial {\bf z}}

に従うことが連鎖律から導くことができます(論文(4)式)。この導出をAppendixに示しました。

Appendix, 論文(4)式の導出

{\bf a}(t)
=- \frac{\partial L}{\partial {\bf z}(t)}
=- \frac{\partial L}{\partial {\bf z}(t+\Delta t)}\frac{\partial {\bf z}(t+\Delta t)}{\partial {\bf z}(t)}
={\bf a}(t+\Delta t) \frac{\partial T_{\Delta t}}{\partial {\bf z}(t)}

途中

T_{\Delta t}={\bf z} (t+\Delta t) 

としました。ここで

\frac{d{\bf z}}{dt} = f({\bf z}(t), t, \theta)

より

T_{\Delta t} 
= {\bf z} (t+\Delta t) 
= {\bf z} (t) + \int_t^{t+\Delta t} f dt 
\simeq {\bf z} (t) +f \Delta t

よって

\begin{align}
\frac{d{\bf a}}{dt}
&=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t)}{\Delta t} 
=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t+ \Delta t) \frac{\partial T_{\Delta t}}{\partial {\bf z}}}{\Delta t} \\
&=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t+\Delta t)\frac{\partial}{\partial {\bf z}} ({\bf z} + f \Delta t)}{\Delta t} 
=\lim_{\Delta t \to 0} \frac{{\bf a}(t+\Delta t) -{\bf a}(t+ \Delta t) \left( {\bf I} + \frac{\partial f}{\partial {\bf z}} \Delta t \right)}{\Delta t} \\
&= -\lim_{\Delta t \to 0} {\bf a}(t+\Delta t) \frac{\partial f}{\partial {\bf z}}
=-{\bf a}(t) \frac{\partial f}{\partial {\bf z}}
\end{align}

となります。

参考記事

AINOU記事
SlideShare

25
19
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
25
19