1
2

More than 1 year has passed since last update.

ニューラルネットワークの予測の不確実性(stochastic variational inference・概要編)

Last updated at Posted at 2023-07-17

はじめに

stochastic variational inference で予測の不確実性を算出する方法の概要を説明します。

事前準備(変分ベイズ)

前提知識として必要なベイズ推論と変分推論について説明します。

ベイズ推論

通常のニューラルネットワークでは、出力は変数ですが、ベイズ推論ではラベル$y$の分布$p(y|x)$を考えます。
分布が尖った形になっている場合は予測の不確実性が低く、裾が広い場合は不確実性が高いことになります。

ラベルの分布は、パラメータの分布$p(\theta)$を考え、積分で$\theta$を消去することで求めます。
$$
p(y|x) = \int p(y|x, \theta)p(\theta) {\rm d}\theta
$$

変分推論

パラメータの事後分布$p(\theta|X,Y)$を、計算しやすい分布$q(\theta)$で近似することを考えます。
近似分布$q(\theta)$は、$p(\theta|X,Y)$とのKLダイバージェンスが小さければ、よい近似となります。
$$
D_{KL}[q(\theta)||p(\theta|X,Y)] = \int q(\theta)\ln\frac{q(\theta)}{p(\theta|X,Y)} {\rm d}\theta
$$

しかし、直接最小化することは難しいため、データの周辺分布 $\ln p(X,Y)$ とその下界(ELBO)$\mathcal{L}(\xi)$ を考えます。
これらと、上記KLダイバージェンスには、下記の関係があることが知られています。
$$
\ln p(X,Y) = \mathcal{L}(\xi) + D_{KL}[q(\theta;\xi)||p(\theta|X,Y)]
$$
ここで、$\xi$は$q(\theta)$の形を決めるパラメータで、変分推論では$\xi$を最適化します。

この関係性を用いて、KLダイバージェンスを最小化する問題を、ELBOを最大化する問題に置き換えます。
ELBOは下記で定義されます。
$$
\mathcal{L}(\xi) = \int q(\theta;\xi)\ln \frac{p(X,Y,\theta)}{q(\theta;\xi)} {\rm d}\theta
$$

同時確率を条件付き確率に書き換えると、$p(X,Y,\theta)=p(Y|X,\theta)p(X|\theta)p(\theta)$となり、$X$と$\theta$は独立のため $p(X,Y,\theta)=p(Y|X,\theta)p(X)p(\theta)$ と書くことができます。
これをELBOに代入すると下記のようになります。

\mathcal{L}(\xi) = \int q(\theta;\xi) \ln p(Y|X, \theta) {\rm d}\theta + \int q(\theta;\xi) \ln p(X) {\rm d}\theta + \int q(\theta;\xi) \ln \frac{p(\theta)}{q(\theta;\xi)} {\rm d}\theta
= \mathbb{E}_{q(\theta;\xi)}[\ln p(Y|X, \theta)] + \ln p(X) - D_{KL}[q(\theta;\xi)||p(\theta)] 

$\ln p(X)$は定数なので、ELBOを最大化するために、下記の負の対数尤度と$q(\theta)$とパラメータの事前分布とのKLダイバージェンスの和を最小化すればよいことが分かります。

\mathbb{E}_{q(\theta;\xi)}[-\ln p(Y|X, \theta)] + D_{KL}[q(\theta;\xi)||p(\theta)]

Stochastic Variational Inference

負の対数尤度の期待値を最小化するために、各バッチで$q(\theta;\xi)$からパラメータをサンプリングして、そのパラメータでの負の対数尤度を最小化します。
$q(\theta;\xi)$は勾配降下法で最適化できる必要があるため、$q(\theta;\xi)$の実装では下記のような方法が用いられます。

reparametrization trick

$q(\theta;\xi)$を、平均$\mu$、標準偏差$\sigma^2$の正規分布$\mathcal{N}(\mu,\sigma^2)$からサンプリングする場合、まず、平均0、標準偏差1の正規分布に従うノイズ$\epsilon \sim \mathcal{N}(0,1)$をサンプリングし、$\theta = \mu + \sigma \epsilon$をパラメータとして用います。
$\mu$と$\sigma$を勾配降下法で最適化します。

flipout

上記の手法だと、ミニバッチ内のサンプルで$\epsilon$が共有されてしまう問題があります。
そこで、ノイズにランダムな符号ベクトル$r_ns_n$をかけ、$\theta=\mu + \sigma\epsilon r_ns_n^T$を用いることで、学習の効率化を行います。

回帰タスクの対数尤度

サンプリングされたパラメータ$\theta$でのデータ$x$の予測を$p(y|x,\theta)=\mathcal{N}(f(x|\theta),\sigma^2)$、正解を$y$とすると、対数尤度は$\ln p(y|x,\theta)=-\frac{\ln 2\pi\sigma^2}{2} - \frac{(y-f(x|\theta))^2}{2\sigma^2}$となります。
つまり、負の対数尤度を最小化するためには、2乗誤差を最小化すればよいことになります。

不確実性の算出

近似分布$q(\theta;\xi)$からn個のパラメータ${\theta_1, \cdots, \theta_n}$をサンプリングし、事後分布の期待値と分散を計算します。

$$
\mu(y) = \frac{1}{n}\sum_{i=1}^n f(x|\theta_i)
$$
$$
{\rm Var}(y) = \sigma^2 + \frac{1}{n}\sum_{i=1}^n f(x|\theta_i)^T f(x|\theta_i) - \mu(y)^T \mu(y)
$$

学習のまとめ

Stochastic variational inferenceでは、各ミニバッチで近似分布$q(\theta)$からパラメータ$\theta$をサンプリングし、平均二乗誤差と事前分布と近似分布のKLダイバージェンスを最小化します。
$$
\mathcal{L} = \sum_{i=1}^n (y_n - f(x_n|\theta))^2 + \alpha D_{KL}[q(\theta)||p(\theta)]
$$
ここで、$\alpha$は、2つの項のバランスを調整するパラメータです。

参考資料

  • D. P. Kingma and M. Welling, Auto-Encoding Variational Bayes, ICLR, 2014
  • Y. Wen et al., Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-batches, ICLR, 2018.
1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2