Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
Help us understand the problem. What is going on with this article?

Neural ODEの紹介

はじめに

これは、NeuralODEについて整理するためのものです。ですので、間違いは多々あると思うので、本気でNeuralODEを学びたいという方は、元論文や他のプロの方々が解説してる記事をお読みになる事をおすすめします。

イントロ

2020-11-12 (1).png
上の図がこれまでのRNN、下の図が今回のNeuralODEを組み合わせたモデルでのフィットと予測図です。緑の線が本来の時系列データ、緑の点が時系列データを欠損あり&不等間隔にした観測値、青の線がモデルがフィットした線、赤の線が予測値です。これまでのRNNは、概形をなんとなく学習出来てるかな程度ですが、下のモデルの方が上手く学習出来ているのが分かります。フィット部分に関しては、正解の線とぴったりです。

NeuralODEの前にResNet

ResNetとは、深いニューラルネットワークでもうまく学習が行くように取り入れられたものです。
第N層目の隠れ層の状態を$h_N$とすると、ResNetでは、

$$h_N=f(θ,h_{N-1})+h_{N-1}\tag{1}$$

と書けます。つまり、前の層の隠れ状態に関数を作用させたものと、隠れ状態そのものを入力することで、前の層との差を学習させています(ResNetの名前の由来)。一般的なResNetでは、入力$X$に対して何層かのConv層を挟み$y=F(X)$を出力します。そこに、入力$X$を足した$H(X)=F(X)+X$を出力します。あとは、この$H(X)$と正解ラベルから成る損失関数を普通のニューラルネットワークのように学習させるだけです。ResNetは詳しい日本語解説が多くあるので、そちらをご参照ください。
NeuralODEは、このResNetから着想を得ています。

NeuralODE

順番に説明していきます。

隠れ層

NeuralODEでは、これまで離散的に考えていた層の間を微小に幅にして連続値として考えます。先程の式(1)の離散値Nを実数tで書き、隠れ状態をhからtに書き直すと、$$z_t=f(θ,z_{t-1})+z_{t-1}\tag{2}$$となります。これを式変形すると$$z_t-z_{t-1}=f(θ,z_{t-1})\tag{3}$$層の幅をΔtとすると、$$z_t-z_{t-1}\approx f(θ,z_{t-1})×Δt\tag{4}$$となります。ここで層の幅を微小幅dtにしてやると、$$\frac{dz(t)}{dt}=f(θ,z(t),t)\tag{4}$$と書けます。つまり、この常微分方程式が隠れ層のダイナミクスを表現しており、順伝播では、この常微分方程式を解いていくことになります。

順伝播

Neural ODEでは、一般的なニューラルネットワークのように各層で異なる関数$f_N$を持つのではなく、$f$のみで順伝播を行います。つまり、式(4)をtの初期値$t_0$から$t_1$の間で積分するだけです。$$z(t_1)=\int_{t_0}^{t_1}f(θ,z(t),t)dt\tag{5}$$
よって、$z(t_1)$と損失関数$L(z(t_1))$が求まりました。

逆伝播

逆伝播では、誤差を後ろから前に伝播し、最初の層のパラメータの勾配を求めることがゴールになります。従って、求めたい勾配は$\frac{∂L}{∂θ}$です。
では、これをどうやって求めるのかっていうのが今回の大きなポイントで、計算にはadjoint methodという方法を用います。

adjoint method

adjoint methodについて説明します。
まず損失関数の微分を$$a(t)=\frac{∂L}{∂z(t)}\tag{6}$$と置きます。この$a(t)$はadjoint(随伴行列、随伴ベクトル)といいます。$a(t)$は、次のadjoint方程式を満たします。$$\frac{∂a(t)}{∂t}=-a(t)^{T}\frac{∂f(θ,z(t),t)}{∂z}...(Tは転置の意)\tag{7}$$
この常微分方程式は、$a(t_1)$を初期値として、$t_1$から$t_0$方向に解いていきます。$a(t_1)$は式(3)より$a(t_1)=\frac{∂L}{∂z(t_1)}$とすぐ求まります。ここまでがadjoint methodです。
同様の事を、パラメータ$θ$についても行います。
逆伝播の目的の$\frac{∂L}{∂\theta}$は、$$\frac{∂L}{∂\theta}=a_{θ}(t)=-\int_{t_1}^{t_0}a(t)^{T}\frac{∂f(θ,z(t),t)}{∂\theta}dt\tag{8}$$で計算できます。ここで、$\frac{∂f(θ,z(t),t)}{∂\theta}$は各時刻におけるヤコビアンです。初期値$a_θ(t_1)$についてですが、パラメータを最適化した結果、損失関数$L$はθに依存しなくなる、つまり、$\frac{∂L}{∂θ}=0$となるので、初期値は$a_θ(t_1)=\frac{∂L}{∂θ(t_1)}=0$として計算をします。
以上で、計算アルゴリズムの説明を終了します。正直、式(6)と式(7)が、なぜこうなるかは分かってないのですが、とりあえずこういうもんだと思っておきます。

4つの利点

これまでややこしい計算式を書いてきましたが、これによって得られる4つの利点を簡単に紹介します。この4つは元論文で述べられていたものです。

Memory efficiency

これまでのニューラルネットワークでは、隠れ層の各層それぞれで異なる関数$f_n(\theta)$が存在してましたが、NeuralODEはたった一つの関数$f(\theta,z(t),t)$で隠れ層を支配するので、モデルのパラメータが少なくなります。結果的にメモリ使用量を抑えることが出来ます。

Adaptive computation

NeuralODEは常微分方程式の数値解を求めるものでした。モデルの精度を自由に調整できます。従って、エッジコンピュータなどの比較的小さい端末でも必要に応じて軽く動かすことが出来ます。

Scalable and invertible normalizing flows

Normalizing flowsとは、変分推論の手法の1つで、標準正規分布(ガウス分布)に非線形変換を行うことで複雑な確率分布を表現するというものです。自分は、この辺はちんぷんかんぷんなので詳しい内容は分からないのですが、離散的な式(1)がNormalizing Flowsで用いられるようです。そして、ヤコビアンの行列式を計算するのですが、隠れ層または入力データzの次元の3乗の計算コストになるそうです。しかし、層を離散的から連続的に変換することで正規化定数の計算が単純化されるようです。詳しくは、論文をお読みになってください。

Continuous time-series models

従来のRNNは、入力される系列データは等間隔で欠損値がないのが前提でした。なぜなら、隠れ層が離散的だったので等間隔でしか情報を扱えませんでした。
NeuralODEは、隠れ層を連続化してるので、入力される時系列データが不等間隔でサンプリングされたものや欠損値があっても、最初の画像のように上手く学習が出来ます。

補足

一応、不等間隔でも学習させる方法はあって、1つ目が時間差Δtと観測値の二つを入力する方法で$h_N=RNN(h_{N-1},Δt,x_N)$す。2つ目がRNNDecayモデルで、時間差Δtの大きさによって前の隠れ層との相関を考慮したものです。$h_N=RNN(h_{N-1}・\exp(-τΔt),x_N)$(ただし、τは減衰率)とすることで、直前のデータが強く影響し、大幅な時間差であれば忘れてしまおう、てな感じです。なお、予測タスクにおいて、後者のモデルは前者のモデルより良い性能が出なかったそうです。
自分は、これまで時間$t$そのものと、観測値を学習させていました。これらは本質的には一緒なのかが気になります。自分の直感的なイメージは、時間差より時間そのものを学習させた方が系列データであることを強く認識させられると思ってるのですが。時間があるときに調べて試してみようと思います。

この記事を書くことで、だいぶNeuralODEについて整理が出来ました。次は、これを実装していきたいと思います。GitHubに、元論文の著者さんがRNNと組み合わせたサンプルコードもあるので、それを参考に自分のデータに適用したいと思ってます。

参考サイト

修正史

yu_og
理系卒研生。プログラミング初心者。知識定着のために記事投稿。 免責事項:私が発する情報が正確な情報になるよう細心の注意を払っておりますが、当情報において正確性等について保証するものではなく、一切の責任を負いません。
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away