5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

モーメントマッチングから仮定密度フィルタリング、期待値伝播法(1)

Last updated at Posted at 2021-07-31

「ベイズ深層学習」4章 近似ベイズ推論より
現在の近似推論はMCMCや変分推論が主流であり、期待値伝播法はベイズ深層学習やPRMLでも紹介されているにも関わらず日本語の記事はわずかしかなかったのでこの記事を書こうと思いました。

分量が多くなってしまったので期待値伝播法は次回書きたいと思います。

モーメントマッチング

まず、仮定密度フィルタリングや期待値伝播法の基礎となるモーメントマッチングについて解説します。
ある複雑な確率分布$p(\boldsymbol{z})$が存在し、それを指数型分布族$q(\boldsymbol{z})$で近似したいとします。
$$q(\boldsymbol{z};\boldsymbol{\eta}) = h(\boldsymbol{z})exp(\boldsymbol{\eta} ^\top \boldsymbol{t}(\boldsymbol{z}) - a(\boldsymbol{\eta}))$$
モーメントマッチングの目的は、変分推論などと同様にKLダイバージェンスを最小化することですが、$q(z)$は指数型分布族と仮定しているため以下のようにして最小化を実現できます。

\begin{align}
D_{KL}[p(\boldsymbol{z})||q(\boldsymbol{z};\boldsymbol{\eta}))] & = \int_{-\infty}^\infty p(\boldsymbol{z}) log\frac{p(\boldsymbol{z})}{q(\boldsymbol{z};\boldsymbol{\eta}))} dz \\
& =  - \mathbb{E}_p[ln \ q(\boldsymbol{z};\boldsymbol{\eta})] + \mathbb{E}_p[ln \ p(\boldsymbol{z})] \\
\end{align}

自然パラメータ$\boldsymbol{\eta}$に関して勾配をとって0とすると(途中計算省略)
$$\mathbb{E}_q[\boldsymbol{t}(\boldsymbol{z})] = \mathbb{E}_p[\boldsymbol{t(\boldsymbol{z})}]$$
この$\boldsymbol{t(\boldsymbol{z})}$は十分統計量です。

※変分推論とは元の分布$p$と近似分布$q$の位置が逆になっていることに注意してください。$D_{KL}[p(z)||q(z)] $の最小化と$D_{KL}[q(z)||p(z)]$の最小化で得られる近似分布は異なってきます。(KLダイバージェンスは距離ではないため)
( 参考: https://www.slideshare.net/ssuser8672d7/ss-147555894)

仮定密度フィルタリング

↑の説明で、私は最初
「p(z)のモーメントが計算可能ということは分布の形状も分かっていてわざわざ近似する必要ないのでは?」などとアホなことを考えていましたが、共役性の成り立たない事後分布ではモーメントの計算は可能だが分布の形状はわからないケースが存在します。

そこで仮定密度フィルタリングでは、共役性の成り立たないモデルにおいて、事後分布を事前分布と同じ分布で近似するという考えのもと、事後分布のパラメータを逐次的に学習させていきます。

  • 共役性が成り立っている場合
    データ$\boldsymbol{x}_1 $を観測したとき、パラメータ$\boldsymbol{\theta}$の事後分布は
    $$
    p(\boldsymbol{\theta}|\boldsymbol{x}_1) \propto p(\boldsymbol{x}_1|\boldsymbol{\theta})p(\boldsymbol{\theta})
    $$ それからデータ$\boldsymbol{x}_2$を観測したとすると事後分布は解析的に計算することが可能です。
    $$
    p(\boldsymbol{\theta}|\boldsymbol{x}_1,\boldsymbol{x}_2) \propto p(\boldsymbol{x}_2|\boldsymbol{\theta})p(\boldsymbol{\theta}|\boldsymbol{x}_1)
    $$
  • 共役性が成り立たない場合
    解析的な計算はできません。
    そのため近似事後分布を事前分布と同じ指数型分布として、データ$\boldsymbol{x_1}$を観測したときの近似事後分布$q_1(\boldsymbol{\theta})$をモーメントマッチングで求めます。
    $$
    q_1(\boldsymbol{\theta}) \overset{MM} \approx r_{1}(\boldsymbol{\theta}) = \frac{1}{Z}_1p(\boldsymbol{x} _i |\boldsymbol{\theta})p(\boldsymbol{\theta})
    $$

こうして得られた近似事後分布を次の事前分布とする。これにより、分布の形状を保ったまま逐次的な更新が可能です。
$$
q_{i+1}(\boldsymbol{\theta}) \overset{MM}{\approx} r_{i+1}(\boldsymbol{\theta}) = \frac{1}{Z_{i+1}}p(\boldsymbol{x}_{i+1}|\boldsymbol{\theta})q_i(\boldsymbol{\theta})
$$

事前分布、つまりは近似事後分布を正規分布としたときのモーメントマッチングの更新式を導出します。
尤度は$\ f_i(\boldsymbol{\theta})$とします。

$q_{i+1}(\boldsymbol{\theta})$の正規化定数$\ Z_{i+1}$は、

\begin{align}
Z_{i+1} & =  \int f_{i+1}(\boldsymbol{\theta})q_i(\boldsymbol{\theta})d\boldsymbol{\theta} \\
& =  \int f_{i+1}(\boldsymbol{\theta}) \frac{1}{(2\pi|V_i|)^{d/2}} \exp{\left(-\frac{1}{2}\left(\boldsymbol{\theta}-\boldsymbol{\mu}_i\right)^\top V_i^{-1}\left( \boldsymbol{\theta}-\boldsymbol{\mu}_i\right)\right)d\boldsymbol{\theta}} \\
\end{align}

これの対数をとったものを $\boldsymbol{\mu}_{i}$で微分すると

\begin{align}
 \frac{\partial{}}{\partial{\boldsymbol{\mu}_i}} ln\ Z_{i+1}& = \frac{1}{Z_{i+1}} \int f_{i+1}(\boldsymbol{\theta}) N(\boldsymbol{\theta}\ |\ \boldsymbol{\mu}_{i},V_i) V_i^{-1}( \boldsymbol{\theta}-\boldsymbol{\mu}_i)d\boldsymbol{\theta} \\
& = V_i^{-1}(\mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}] -  \boldsymbol{\mu}_i)
\end{align}

よって、分布$\ r_{i+1}(\boldsymbol{\theta})$の一次のモーメントは

\mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}] = \boldsymbol{\mu}_i + V_i\frac{\partial{}}{\partial{\boldsymbol{\mu}_i}} ln\ Z_{i+1}

更新後の近似事後分布$q_{i+1}(\boldsymbol{\theta})$の平均パラメータ$\ \mu_{i+1} $はそのまま

\begin{align}
\mu_{i+1} &= \mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}] \\ 
&=  \boldsymbol{\mu}_i + V_i\frac{\partial{}}{\partial{\boldsymbol{\mu}_i}} ln\ Z_{i+1}
\end{align}

となります。

同様にして、共分散行列$V_i$で微分します。
これで正しいのか微妙ですが、行列式の行列微分と二次形式の行列微分の公式を利用し

\begin{align}
 \frac{\partial{}}{\partial{V_i}} ln\ Z_{i+1}& = \frac{1}{Z_{i+1}} \int f_{i+1}(\boldsymbol{\theta}) N(\boldsymbol{\theta}\ |\ \boldsymbol{\mu}_{i},V_i)
\cdot -\frac{1}{2}\left (V_i^{-1} - V_i^{-1}\left( \boldsymbol{\theta}-\boldsymbol{\mu}_i\right )\left( \boldsymbol{\theta}-\boldsymbol{\mu}_i\right )^\top V_i^{-1}\right)d\boldsymbol{\theta} \\
& =  \frac{1}{Z_{i+1}} \int r_{i+1}(\boldsymbol{\theta})
\cdot -\frac{1}{2}\left (V_i^{-1} - V_i^{-1}\left( \boldsymbol{\theta}-\boldsymbol{\mu}_i\right )\left( \boldsymbol{\theta}-\boldsymbol{\mu}_i\right )^\top V_i^{-1}\right)d\boldsymbol{\theta} \\
& =  -\frac{1}{2}V_i^{-1} +\frac{1}{2}V_i^{-1}\left(\mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}\boldsymbol{\theta}^\top] -  \mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}] \boldsymbol{\mu}_i^\top
-\boldsymbol{\mu}_i\mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}]^\top +\boldsymbol{\mu}_i\boldsymbol{\mu}_i^\top \right)V_i^{-1}
\end{align}

よって、分布$\ r_{i+1}(\boldsymbol{\theta})$の二次のモーメントは

\begin{align}
\mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}\boldsymbol{\theta}^\top] &=
2V_i\left(\frac{\partial{}}{\partial{V_i}} ln\ Z_{i+1}\right)V_i + V_i + \mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}] \boldsymbol{\mu}_i^\top
+\boldsymbol{\mu}_i\mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}]^\top - \boldsymbol{\mu}_i\boldsymbol{\mu}_i^\top \\ 
\end{align}

となり、更新後の近似事後分布$q_{i+1}(\boldsymbol{\theta})$の共分散行列$ V_{i+1} $は先程求めた一次のモーメントの値も利用し

\begin{align}
V_{i+1} &= \mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}\boldsymbol{\theta}^\top] - \mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}]\mathbb{E}_{r_{i+1}}[\boldsymbol{\theta}]^\top\\ 
&= 2V_i\left(\frac{\partial{}}{\partial{V_i}} ln\ Z_{i+1}\right)V_i + V_i - 
V_i\frac{\partial{}}{\partial{\boldsymbol{\mu}_{i}}} ln\ Z_{i+1}
\left(V_i\frac{\partial{}}{\partial{\boldsymbol{\mu}_{i}}} ln\ Z_{i+1}\right)^\top \\
& = 
V_i + 2V_i\left(\frac{\partial{}}{\partial{V_i}} ln\ Z_{i+1}\right)V_i - 
V_i\frac{\partial{}}{\partial{\boldsymbol{\mu}_{i}}} ln\ Z_{i+1}
\left(\frac{\partial{}}{\partial{\boldsymbol{\mu}_{i}}} ln\ Z_{i+1}\right)^\top V_i^\top \\
&= 
V_i - V_i\left( \left(\frac{\partial{}}{\partial{\boldsymbol{\mu}_{i}}} ln\ Z_{i+1} \right)
\left(\frac{\partial{}}{\partial{\boldsymbol{\mu}_{i}}} ln\ Z_{i+1}\right)^\top - 
2\frac{\partial{}}{\partial{V_i}} ln\ Z_{i+1}\right)V_i
\end{align}

と計算が可能です。これは近似分布としてガウス分布を選んだときの一般的な結果ですが、$Z_{i+1} = \int f_{i+1}(\boldsymbol{\theta})q_i(\boldsymbol{\theta})d\boldsymbol{\theta}$が計算可能である必要があります。

プロビット回帰への適用

ここで、仮定密度フィルタリングの適用事例を紹介します。
プロビット回帰は2値分類を予測するためのモデルで、入力を$\boldsymbol{x} \in \mathbb{R}^d$、出力を$y \in \{-1,1\}$として以下のような尤度関数で表されます。
(本では入力は一次元ですが多次元に拡張してあります)

\begin{align}
p(\boldsymbol{Y}|\boldsymbol{X},\boldsymbol{w}) & = \prod_{n=1}^N p(y_n|\boldsymbol{x}_n,\boldsymbol{w}) \\
& = \prod_{n=1}^N \Phi(y_n\boldsymbol{w}^\top\boldsymbol{x}_n)
\end{align}

$\Phi$は標準正規分布の累積密度関数です。

パラメータの事前分布は
$$
p(\boldsymbol{w}) = \mathcal{N}(\boldsymbol{w}| \boldsymbol{0},v_0I)
$$
とします。
そして近似事後分布も事前分布と同様に正規分布とします。

\begin{align}
q(\boldsymbol{w}) &= \mathcal{N}(\boldsymbol{w}|\boldsymbol{\mu},V)\\ 

\end{align}

尤度と事前分布が決まったため,近似事後分布を正規分布としてモーメントマッチングをデータの数だけ繰り返すことで最終的な事後分布の近似を得ることができます。

これから、前節で求めたモーメントマッチングの更新式を利用し、プロビットモデルでの更新式を導出していきます。

パラメータの更新には正規化定数が解析的に求まる必要がありましたが、この場合正規化定数は、以下のようになるらしいです。(誰かここの計算過程教えて下さい)
参照 https://tminka.github.io/papers/ep/minka-ep-quickref.pdf

\begin{align}
Z_{i+1} &= \int  \Phi(y_{i+1}\boldsymbol{w}^\top\boldsymbol{x}_{i+1})\mathcal{N}(\boldsymbol{w}|\boldsymbol{\mu_{i}},V_i)d\boldsymbol{w}\\
&= \Phi(a_{i+1})\\
\\
a_{i+1} &= \frac{y_{i+1}\boldsymbol{\mu}_{i}^{\top} \boldsymbol{x}_{i+1}}{\sqrt{1 + \boldsymbol{x}_{i+1}^{\top}V_{i}\boldsymbol{x}_{i+1}}}
\end{align}

$ln  Z_{i+1} $の微分を導出していきます。

\begin{align}
 \frac{\partial{}}{\partial{\boldsymbol{\mu}_i}} ln\ Z_{i+1} &= 
\frac{1}{Z_{i+1}}\frac{\partial{}}{\partial{\boldsymbol{\mu}_i}} \int_{-\infty}^{a_{i+1}}  \mathcal{N}(0,1)d\boldsymbol{w}\\
&=\frac{1}{Z_{i+1}}\mathcal{N}(a_{i+1}|0,1)\frac{\partial{}}{\partial{\boldsymbol{\mu}_i}} a_{i+1}\\
&= \frac{1}{Z_{i+1}}\mathcal{N}(a_{i+1}|0,1)\frac{y_{i+1} \boldsymbol{x}_{i+1}}{\sqrt{1 + \boldsymbol{x}_{i+1}^{\top}V_{i}\boldsymbol{x}_{i+1}}}
\end{align}
\begin{align}
 \frac{\partial{}}{\partial{V_i}} ln\ Z_{i+1} &= 
\frac{1}{Z_{i+1}}\frac{\partial{}}{\partial{V_i}} \int_{-\infty}^{a_{i+1}} \frac{1}{\sqrt{2\pi}}\exp \left(-\frac{x^2}{2}\right)dx \\
&= \frac{1}{Z_{i+1}}\mathcal{N}(a_{i+1}|0,1)\frac{\partial{}}{\partial{V_i}} a_{i+1}\\
&= \frac{1}{Z_{i+1}}\mathcal{N}(a_{i+1}|0,1)a_{i+1}\left(-\frac{1}{2}\frac{\boldsymbol{x}_{i+1}\boldsymbol{x}_{i+1}^\top}{1 + \boldsymbol{x}_{i+1}^{\top}V_{i}\boldsymbol{x}_{i+1}}\right)\\
\end{align}

ようやく必要なものがそろったので、あとは前述のパラメータの更新式に従っていきます。

\boldsymbol{\mu}_{i+1} =  \boldsymbol{\mu}_i + V_i\frac{\partial{}}{\partial{\boldsymbol{\mu}_i}} ln\ Z_{i+1} \\
V_{i+1} =
V_i - V_i\left(\frac{\partial{}}{\partial{\boldsymbol{\mu}_{i}}} ln\ Z_{i+1}
\left(\frac{\partial{}}{\partial{\boldsymbol{\mu}_{i}}} ln\ Z_{i+1}\right)^\top - 
2\frac{\partial{}}{\partial{V_i}} ln\ Z_{i+1}\right)V_i

ここからはpythonでの実装とともに紹介します。
異なる正規分布に従うデータを生成します。

def generate_data(neg_size,pos_size):
    mean_neg = np.array([2, 2])
    cov_neg = np.array([[1, 0], [0, 1]])

    mean_pos = np.array([-2, -1])
    cov_pos = np.array([[1, 0], [0, 1]])
    #定数項を追加
    #ベクトルは記事の表記に合わせ列ベクトルとする
    x_neg = np.array([np.hstack(([1],d)).reshape(3,1) for d in np.random.multivariate_normal(mean_neg, cov_neg, size=neg_size)])
    x_pos = np.array([np.hstack(([1],d)).reshape(3,1) for d in np.random.multivariate_normal(mean_pos, cov_pos, size=pos_size)])

    X = np.concatenate([x_neg,x_pos])
    Y = np.concatenate([-1*np.ones(neg_size),np.ones(pos_size)])
    return X,Y

事前分布は平均0,共分散行列は単位行列の定数倍で初期化します。

class ADFforProbitRegression():
    def __init__(self):
        self.mu = np.array([0,0,0]).reshape(3,1)
        self.v = np.array([[100,0,0],[0,100,0],[0,0,100]])
        self.mus= []
        self.vs= []

平均と共分散行列の更新部分。
表記を記事と近づけているためやっていることは簡単にわかるかと思います。

    def fit(self,X,Y):
        for x,y in zip(X,Y):
            self.update_param(x,y)

    def update_param(self,x,y):
        self.update_mu(x,y)
        self.update_v(x,y)

    def update_mu(self,x,y):
        a = self.calc_a(x,y)
        Z = norm.cdf(a)
        dlnZ_du = self.calc_dlnZ_du(a,Z,x,y)
        
        self.mu = self.mu + self.v @ dlnZ_du

    def update_v(self,x,y):
        a = self.calc_a(x,y)
        Z = norm.cdf(a)
        dlnZ_du = self.calc_dlnZ_du(a,Z,x,y)
        dlnZ_dv = self.calc_dlnZ_dv(a,Z,x,y)
        
        self.v = self.v- self.v @ (dlnZ_du @ dlnZ_du.T - 2*dlnZ_dv) @ self.v

    def calc_dlnZ_du(self,a,Z,x,y):
        return (1/Z) * norm.pdf(a) * (y*x / np.sqrt(1+(x.T @ self.v @ x)[0][0]))
    def calc_dlnZ_dv(self,a,Z,x,y):
        return (1/Z) * norm.pdf(a) * a* (-(x @ x.T))/(2*(1+(x.T @ self.v @ x)[0][0]))
    def calc_a(self,x,y):
        a =y*(self.mu.T @ x)[0][0] / np.sqrt(1+(x.T @ self.v @ x)[0][0])
        return a

全てのデータを学習すると以下のような近似事後分布が得られます。

q(\boldsymbol{w}) = \mathcal{N}(\boldsymbol{w}|\boldsymbol{\mu^*},V^*)

試しにいくつか$ \boldsymbol{w}$をサンプリングし、分類確率$\Phi(\boldsymbol{w}^\top \boldsymbol{x})$が0.5となる境界線を描画しました。

random_w.png

最尤推定などでは得られない、不確実性が伴う回帰結果が得られているように見えます。

次に予測分布を考えます。
新しいデータ$\ \boldsymbol{x}_{new}$が与えられたとき、出力が1となる確率は

\begin{align}
p(y=1|\boldsymbol{x}_{new},\boldsymbol{\mu}^* ,V^*) &= \int p(y=1|\boldsymbol{x}_{new},\boldsymbol{w})p(\boldsymbol{w}|\boldsymbol{\mu}^* ,V^*)d\boldsymbol{w} \\
&= \int \Phi(\boldsymbol{w}^\top \boldsymbol{x}_{new}) \mathcal{N}(\boldsymbol{w}|\boldsymbol{\mu}^* ,V^*)d\boldsymbol{w} \\
&= \Phi \left(\frac{\boldsymbol{\mu}^{*\top} \boldsymbol{x}_{new}}{\sqrt{1 + \boldsymbol{x}_{new}^{\top}V^*\boldsymbol{x}_{new}}}\right)
\end{align}

ここで、

\Phi \left(\frac{\boldsymbol{\mu}^{*\top} \boldsymbol{x}_{new}}{\sqrt{1 + \boldsymbol{x}_{new}^{\top}V^*\boldsymbol{x}_{new}}}\right) = p^*

となるような$\ \boldsymbol{x}_{new}$を考えると、確率$ p^* $で1と分類できるような境界線を考えることができ、その境界線を描画しました。境界線といっても厳密に確率が$ p^* $である点ではなくだいたい$p^*$である領域となっていますが。
黒い領域は確率が50%、黄色の領域は95%、緑の領域は99%でどちらかのカテゴリに分類されます。

result_adf.png

よく分類できているかのように見えますが、仮定密度フィルタリングでは結果がデータの学習順序に依存してしまうという欠点があります。
それを確認するため、外れ値を加えたデータに対し、外れ値を最後に学習したときの学習状況の変化をず図にしました。
(ここでの$\boldsymbol{w}$はその学習時点での平均$\ \boldsymbol{\mu}$です。)

result_adf.gif

最後に外れ値を学習するため、境界線が大きく変化していることがわかります。
学習データをシャッフルし、外れ値の学習が最後にならないようにすると、
result_adf.gif

途中境界線が外れ値を学習したことで大きく変化しますがその後また戻っています。
このように学習順序に依存してしまい、結果が不安定になることが仮定密度フィルタリングのデメリットです。

5
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?