LoginSignup
3
4

More than 3 years have passed since last update.

『ベイズ深層学習』勉強会資料 5章の1

Last updated at Posted at 2019-12-08

記事の概要

『ベイズ深層学習』の輪読会で使用する資料です。
記事では『ベイズ深層学習』を持っていることを前提として、一部の数式や説明を省略しています。同書を持っている人が自習する時の参考になることを目的としています。

記事中の誤りは全て私の理解不足に起因し、輪読会参加者および本の著者、引用記事の作成者とは無関係になります。
誤りや不明点がございましたら、コメント欄などでご指摘いただけると助かります。

ベイズニューラルネットワーク

ベイズニューラルネットワークモデル

準伝播型ニューラルネットワークをベイズ化する。
以下でニューラルネットワークのパラメータの事後分布を求める。

入力を以下とし、

\begin{eqnarray}
\mathbf{X} = \{ \mathbf{x}_1, \cdots, \mathbf{x}_n \}
\end{eqnarray}

観測データを以下とした場合、

\begin{eqnarray}
\mathbf{Y} = \{ \mathbf{y}_1, \cdots, \mathbf{y}_n \}
\end{eqnarray}

パラメータ$\mathbf{W}$の同時分布は以下になる。観測モデルにはガウス分布を用いる。

\begin{eqnarray}
p(\mathbf{Y}, \mathbf{W} | \mathbf{X}) 
&=& p(\mathbf{W}) \prod_{n=1}^N p(\mathbf{y}_n | \mathbf{x}_n, \mathbf{W}) \\
&=& p(\mathbf{W}) \prod_{n=1}^N \mathcal{N}(\mathbf{y}_n | f(\mathbf{x}_n ; \mathbf{W}) , \sigma_y^2 \mathbf{I})
\end{eqnarray}

ここで$\sigma_y^2$はノイズパラメータ、そして$f(\mathbf{x}_n ; \mathbf{W})$がニューラルネットワークである。
例えば2層のニューラルネットワークの場合、$f(\mathbf{x} ; \mathbf{W})$の$d$次元目の出力は(典型的な2層ニューラルネットワークである)以下になる。

\begin{eqnarray}
f_d(\mathbf{x}_n ; \mathbf{W}) = \sum_{h_1=1}^{H_1} \omega_{d,h_1}^{(2)} \phi \biggl( \sum_{h_0=1}^{H_0} \omega_{h_1,h_0}^{(1)} x_{n,h_0} \biggl)
\end{eqnarray}

$h_0$:入力層、$h_1$:隠れ層、$d$:出力層

パラメータの事後分布の計算にはパラメータの事前分布が必要なので、$W$は以下のガウス分布に従うとする。

\begin{eqnarray}
p(\omega) = \mathcal{N}(\omega | 0, \sigma_{\omega}^2)
\end{eqnarray}

ラプラス近似による学習

ラプラス近似とは、事後分布をMAP推定値のガウス分布で近似する推論手法である。
ラプラス近似されたパラメータの事後分布$q(\mathbf{W})$を以下とする。

\begin{eqnarray}
q(\mathbf{W})
&=& \mathcal{N}(\mathbf{W} | \mathbf{W}_{MAP}, \{  \Lambda(\mathbf{W}_{MAP}) \}^{-1})
\end{eqnarray}
\begin{eqnarray}
\Lambda(\mathbf{W}) \equiv - \nabla_{\mathbf{W}}^2 \ln p(\mathbf{W} | \mathbf{Y},\mathbf{X})
\end{eqnarray}

ここで$p(\mathbf{W} | \mathbf{Y},\mathbf{X})$の値が最大となる点を$\mathbf{W}_{MAP}$としている。

事後分布の計算

step1:MAP推定値を求める

$p(\mathbf{W} | \mathbf{Y},\mathbf{X})$が最大値を取るパラメータ$\mathbf{W}_{MAP}$を求めるには、学習によりパラメータの最適解を求めればいい。
具体的には、本書p114の(5.6)のパラメータ更新式を繰り返し計算して最適解に収束させる。
(5.6)の2項目は(5.8)より、通常の微分や誤差逆伝播法で計算できる。
$\ln p(\mathbf{W} | \mathbf{Y},\mathbf{X})$は本書P78の3.4.2の(3.100)で与えられているので、これをパラメータについて偏微分すれば(5.8)が求まる。

また(5.8)より$\Lambda(\mathbf{W})$は以下になる
ここで(2.42)のヘッセ行列の定義と(2.11)の正則化項の定義を用いた。

\begin{eqnarray}
\Lambda(\mathbf{W}) 
&=& \frac{1}{\sigma_y^2} \nabla_{\mathbf{W}}^2 E(\mathbf{W}) + \frac{1}{\sigma_{\omega}^2} \nabla_{\mathbf{W}}^2 \Omega_{L2}(\mathbf{W}) \\
&=& \frac{1}{\sigma_y^2} \mathbf{H} + \frac{1}{\sigma_{\omega}^2} \mathbf{I}
\end{eqnarray}

実際の計算においては、ヘッセ行列の近似(2.58)を用いるとよいかもしれない。

step2:予測分布を近似する

パラメータの近似事後分布を用いて、入力$ \mathbf{x} _* $ に対する出力 $ y _* $の予測分布を近似する。

\begin{eqnarray}
p(y_*| \mathbf{x}_*, \mathbf{Y},\mathbf{X}) 
&\approx&
\int p(y_* | \mathbf{x}_*, \mathbf{X}) q(\mathbf{W})  d \mathbf{W}
\end{eqnarray}

$p(y_* | \mathbf{x}_*, \mathbf{X})$ はニューラルネットワークを含むので解析的に計算できない。

そこでニューラルネットワーク$f(\mathbf{x}_* ; \mathbf{W})$ を$\mathbf{W} _{MAP}$ について線形近似する。

\begin{eqnarray}
f(\mathbf{x}_* ; \mathbf{W})
&\approx&
f(\mathbf{x}_* ; \mathbf{W}_{MAP}) + \mathbf{g}^T  (\mathbf{W} - \mathbf{W}_{MAP})
\end{eqnarray}
\begin{eqnarray}
\mathbf{g}
&\equiv&
\nabla_{\mathbf{W}} f(\mathbf{x}_* ; \mathbf{W})|_{\mathbf{W} = \mathbf{W}_{MAP}}
\end{eqnarray}

これは非線形関数を含まないので解析的に計算できる。

(5.14)の導出には以下の公式を用いる。
公式の証明はPRML(パターン認識と機械学習)の2.3.3を参照されたい。

$\mathbf{x}$の周辺ガウス分布が以下で与えられるとする

\begin{eqnarray}
p(\mathbf{x}) = \mathcal{N}(\mathbf{x} | \mu, \Lambda^{-1})
\end{eqnarray}

また、$\mathbf{y}$の条件付きガウス分布が以下で与えられるとする

\begin{eqnarray}
p(\mathbf{y} | \mathbf{x}) = \mathcal{N}(\mathbf{y} | \mathbf{A} \mathbf{x} + \mathbf{b} , L^{-1})
\end{eqnarray}

その時、$y$の周辺分布は以下となる。

\begin{eqnarray}
p(\mathbf{y}) 
&=& 
\int p(\mathbf{y} | \mathbf{x})  p(\mathbf{x}) d \mathbf{x} \\
&=& \mathcal{N}(\mathbf{y} | \mathbf{A} \mu + \mathbf{b} , L^{-1} + \mathbf{A} \Lambda^{-1} \mathbf{A}^T)
\end{eqnarray}

予測分布の近似は、観測モデルにガウス分布を用いれば以下のようになるので、

\begin{eqnarray}
p(y_*| \mathbf{x}_*, \mathbf{Y},\mathbf{X}) 
&\approx&
\int p(y_* | \mathbf{x}_*, \mathbf{X}) q(\mathbf{W})  d \mathbf{W} \\
&\approx&
\int 
\mathcal{N}(y_*| f(\mathbf{x}_* ; \mathbf{W}_{MAP}) + \mathbf{g}^T  (\mathbf{W} - \mathbf{W}_{MAP}), \sigma_y^2)
\mathcal{N}(\mathbf{W} | \mathbf{W}_{MAP}, \{  \Lambda(\mathbf{W}_{MAP}) \}^{-1})  d \mathbf{W} \\
\end{eqnarray}

上記の公式を適用すると、以下の置き換えにより

$\mathbf{x} \to \mathbf{W}$

$\mu \to \mathbf{W}_{MAP}$

$\Lambda^{-1} \to { \Lambda (\mathbf{W}_{MAP}) } ^{-1}$

$\mathbf{y} \to y_*$

$\mathbf{A} \to \mathbf{g}^T$

$\mathbf{b} \to f(\mathbf{x} _* ; \mathbf{W} _{MAP}) - \mathbf{g}^T \mathbf{W} _{MAP}$

$L^{-1} \to \sigma_y^2$

(5.14)が求まる。

\begin{eqnarray}
&& \mathcal{N}(y_* | \mathbf{g}^T \mathbf{W}_{MAP} +  f(\mathbf{x} _* ; \mathbf{W} _{MAP}) - \mathbf{g}^T \mathbf{W} _{MAP} , \sigma_y^2 + \mathbf{g}^T \{ \Lambda (\mathbf{W}_{MAP}) \}^{-1} \mathbf{g}) \\
&=&
\mathcal{N}(y_* | f(\mathbf{x} _* ; \mathbf{W} _{MAP}), \sigma_y^2 + \mathbf{g}^T \{ \Lambda (\mathbf{W}_{MAP}) \}^{-1} \mathbf{g})
\end{eqnarray}

ハミルトンモンテカルロ法による学習

(5.16)のポテンシャルエネルギーを用いれば、(4.1.16)と同様にパラメータ$\mathbf{W}$の更新式が得られる。

\begin{eqnarray}
\mathbf{W}_{new}
&=& \mathbf{W}_{old} - \frac{\epsilon^2}{2} \nabla_{\mathbf{W}} \mathcal{U} + \epsilon \mathbf{p} \\
&=& \mathbf{W}_{old} + \frac{\epsilon^2}{2} \Bigl(\nabla_{\mathbf{W}} \ln p(\mathbf{Y} | \mathbf{X}, \mathbf{W}) + \nabla_{\mathbf{W}} \ln p(\mathbf{W}) \Bigl) + \epsilon \mathbf{p}
\end{eqnarray}

ここで運動量は

\begin{eqnarray}
\mathbf{p} \sim \mathcal{N}(0, \mathbf{I})
\end{eqnarray}

とおく。

今後の予定

5.1.3.1および5.2以降も勉強会で行ったが、数式だけを追いかけても理解したという実感は得られなかった。
実際に実装してみないと分からないと感じた。

5.1の内容についてはpytorchとPyroで実装されている方がいた。

【ベイズ深層学習】Pyroでベイズニューラルネットワークモデルの近似ベイズ推論の実装

pyroの使用方法を知らないので、以下を参照して勉強したい。
Pyro 0.3.0 : Examples : ベイジアン回帰 – 推論アルゴリズム (Part 2)
確率的プログラミング言語Pyroと変分ベイズ推論の基本

5.2以降の内容については以下の実装などを参照して勉強したい。
https://github.com/Ivan1248/deep-learning-uncertainty
Uncertainty in Deep Learning

関連記事

『ベイズ深層学習』勉強会資料 4章の1
『ベイズ深層学習』勉強会資料 4章の2.1~2.4
期待値伝播法

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4