はじめに
この記事は時系列分析用のLatent ODE (潜在常微分方程式) の損失関数の設計についての備忘録です。数式メインですが実装寄りの内容になっています。
また、潜在空間は多変量正規分布かつ独立を仮定しています。
Latent ODEにおけるELBO
ELBO (Evidence Lower Bound) を使用したLatent ODEの損失関数の定義はVAEと同様に
\mathcal{L}_\mathrm{ELBO}=-\mathbb{E}_{q_\phi(z_0| x_{1:T})}\left[\sum_{t=1}^T\log p_\theta(x_t|z_t)\right]+\beta\mathrm{KL}\left(q_\phi(z_0|x_{1:T})\|p(z_0)\right)
$\beta$はKL誤差が損失関数に与える影響を制御するための比率でLatent系では徐々に大きくして潜在空間の正規化を強めていきます
添字の$\theta$はデコーダ、$\phi$はエンコーダのパラメータで決まることを意味しており、第1項は再構成誤差と呼ばれます。
なぜELBOが必要か
ELBOは尤度関数$\log p_\theta(x_{1:T})$を直接最大化することが難しいために導入された損失関数です。Latent ODEの対数尤度は初期潜在状態$z_0$での期待値をとって
$$\log p_\theta(x_{1:T})=\log\int p_\theta(x_{1:T}|z_0)p(z_0)dz_0$$
と表せます。
$p(z_0)\sim\mathcal{N}(0,1)$に関しては、正規分布の確率密度関数を計算することで求められますが、観測$x_{1:T}$は$z_t\mapsto x_t$で生成され、$z_t$は$z_0$を初期値に非線形に積分されるため、この式は解析的ではありません。
そこで、イェンセンの不等式から
\log p_\theta(x_{1:T})\geq \mathbb{E}_{q_\phi(z_0|x_{1:T})}\left[\log p_\theta(x_{1:T}|z_0)\right]-KL(q_\phi(z_0|x_{1:T})\mid\mid p(z_0))
が成り立ちます。よってELBOを最大化することで間接的に$\log p_\theta(x)$を最大化することができます。
分布の近似
次に、$\log p_\theta(x_{1:T}|z_0)$は各時刻$t$での確率$p_\theta(x_t|z_t)$の和となるため
\log p_\theta(x_{1:T}|z_0)=\log p(x_1|z_1)\cdots p(x_T|z_T)=\sum_t\log p(x_t|z_t)
条件付き期待値の条件が$z_0$から$z_t$に変わっているのは$z_0$がODEの数値積分で変化するためです。表記上は$z_0$でも問題ありません。
さらにELBOではモンテカルロ推定により期待値を近似するため、$z_0=\mu_\phi(x_{1:T})+\sigma^2_\phi({x_{1:T}})\odot\epsilon,\epsilon\sim\mathcal{N}(0,1)$の再パラメータ化を経て$L$本の初期状態のサンプルを取得し、
\mathbb{E}_{q_\phi}\left[\log p_\theta(x_{1:T}|z_0)\right]\approx\frac1L\sum_l\log p_\theta(x_{1:T}|z_0^{(l)})
のように期待値を近似します。$z_0^{(l)}$の上付きの添字は$l$番目のサンプルであることを意味します。
次に、$p(x_t|z_0)$と$p(z_0)$の値を具体的に計算する方法を考えます。観測$x_{1:T}$が与えられたときの初期潜在状態$z_0$の分布は、条件がない場合は標準正規分$p(z_0)\sim\mathcal{N}(0,1)$に従います。
さらに、$p_\theta(x_{1:T}|z_0)$も正規分布に従うと仮定します。このときのパラメータ$\mu_\theta(z_t),\sigma^2_\theta(z_t)$は$z_t$に対応する値でありデコーダの出力となります。
このときのパラメータは、予測値$\hat x_t$と分散を固定して$\mathcal{N}(\hat x_t,1)$とする場合や、直接的に予測値$\hat x_t$を求めるのではなく$\hat x_t$が従う平均$\mu_\theta(z_t)$と分散$\sigma^2_\theta(z_t)$をデコーダで推定した値を使用します。すると
p_\theta(x_t|z_t)\sim\mathcal{N}(\mu_\theta(x_{1:T}),\sigma^2_\theta(x_{1:T}))
が求められ$p_\theta(x_{1:T}|z_0)$を計算することができます。
KL誤差項の求め方
最後に、KL誤差の求め方ですが離散値のKLダイバージェンスの定義より
\mathrm{KL}(q_\phi(z_0\mid x_{1:T})\mid\mid p(z_0))=\int q_\phi(z_0\mid x_{1:T})\log\frac{q_\phi(z_0\mid x_{1:T})}{p(z_0)}dz_0=\mathbb{E}_{q_\phi}\left[\log q_\phi-\log p_\theta\right]
となります。
$q_\phi(z_0\mid x_{1:T})$はエンコーダから得られた、観測$x_{1:T}$が得られたときの$z_0$に関するパラメータ$\mu_\phi(x_{1:T})$と$\sigma_\phi^2(x_{1:T})$を利用して$q_\phi(z_0\mid x_{1:T})\sim\mathcal{N}(\mu_\phi(x_{1:T}),\sigma_\phi^2(x_{1:T}))$のような正規分布に従うと仮定し、値を求めます。
ここで、正規分布$\mathcal{N}(\mu,\sigma^2)$の確率密度関数$f(x)$の対数をとり
$$\log f(x)=\log\frac{1}{\sqrt{2\pi\sigma^2}}\exp\left[-\frac{(x-\mu)^2}{2\sigma^2}\right]=-\frac12\log2\pi\sigma^2-\frac{(x-\mu)^2}{2\sigma^2}$$
となり、KL誤差は
$$\mathrm{KL}(q_\phi(z_0\mid x_{1:T})\mid\mid\mathcal{N}(0,1))=\mathbb{E}\left[\frac12z^2-\frac12\log\sigma^2-\frac{(z-\mu)^2}{2\sigma^2}\right]$$
となります。また、$z^2$の項は分散を作るように変形すると
$$\mathbb{E}[z^2]=\mathbb{E}[(z-\mu)^2+2z\mu-\mu^2]=\sigma^2+\mu^2$$
となるため、以下のような式が得られ$z$の項が消えます。
\mathrm{KL}(q_\phi(z_0\mid x_{1:T})\mid\mid\mathcal{N}(0,1))=\frac12\mathbb{E}_{q_\theta}[\sigma^2+\mu^2-1-\log\sigma^2]
今回は独立な正規分布を仮定してるため、多変量正規分布でも潜在次元ごとに和を取るだけで良いです。
また、ここでいう平均と分散はエンコーダが出力した潜在状態$z_0$の条件付き分布の$\mu_\phi(x_{1:T})$と$\sigma^2_\phi(x_{1:T})$のことを指します。
IWAE (Importance Weight Autoencoder)
IWAE (Importance Weight Autoencoder) は$L$本のサンプルに対して$z_0$をそれぞれ$K$本サンプリングする手法です。
IWAEの損失関数は
\mathcal{L}_\mathrm{IWAE}=\mathbb{E}_{z_0^{(1:K)}\sim q_\phi(z_0|x_{1:T})}\left[\log\frac1K\sum_k\frac{p_\theta(x_{1:T},z_0^{(k)})}{q_\phi(z_0^{(k)}|x_{1:T})}\right]
で与えられます。IWAEを使用するメリットとして
\mathcal{L}_\mathrm{ELBO}\leq\mathcal{L}_\mathrm{IWAE}\leq\log p(x_{1:T})
が成立するため、ELBOよりも厳密な下界を得られる点があげられます。
ただし、勾配の分散は上昇するため学習の難易度が上がる可能性があります。
計算上のトリック
各項の求め方ですが、$q_\phi(z_0|x_{1:T})$はELBOと同様の方法で求めることができます。
そして、新たに文字$w_k$をおき$p_\theta(x_{1:T},z_0)$を条件付き確率で表し
w_k=\frac{p_\theta(x_{1:T},z_0)}{q_\phi(z_0|x_{1:T})}=\frac{p_\theta(x_{1:T}|z_0)p_\theta(z_0)}{q_\phi(z_0|x_{1:T})}
とします。ただし、$w_k$をそのまま求めると$p(x_{1:T}|z_0)=\prod_t p(x_t,z_0)\approx 0$と確率の積になり、数値的に安定しない(アンダーフローする)可能性があります。
そこで、対数をとり
\log w_k=\log\frac{p_\theta(x_{1:T}|z_0)p_\theta(z_0)}{q_\phi(z_0|x_{1:T})}=\log p_\theta(x_{1:T}|z_0)+\log p_\theta(z_0)-\log q_\phi(z_0|x_{1:T})
とします。そのため、$\mathcal{L}_\mathrm{IWAE}$は
\mathcal{L}_\mathrm{IWAE}=\mathbb{E}_{q_\phi}\left[\log\sum_k\exp(\log w_k)-\log K\right]
のように計算することで、数値を安定化させます。
この操作は、pytorchだとlogsumexp関数でできます。
IWAEとKLダイバージェンスの関係
IWAEはELBOと違ってKL誤差を暗黙的に含んでいます。$\mathcal{L}_\mathrm{IWAE}$を確率の対数の形を使って分解すると
\mathcal{L}_\mathrm{IWAE}=\mathbb{E}_{q_\phi}\left[\log\frac1K\sum_k\exp\left(\log p_\theta(x_{1:T}|z_0)-\log\frac{q_\phi(z_0|x_{1:T})}{p_\theta(z_0)}\right)\right]
となり、exp内の第2項にKLダイバージェンスのような項が作れます。これは厳密にはKLダイバージェンスではありませんが、学習時には潜在空間の正則化に関与します。
参考