12
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ベータ分布でベイズ推定するときの事前分布パラメータ評価

Last updated at Posted at 2018-12-19

はじめに

親愛なる皆様、今日も元気にベイズってますでしょうか?
ベイズ推定ってわりと直感的で、バンディット問題なんかと相性よく私も実務で使いはじめているのですが、いつでも悩みどころになるのが事前分布の設計だと思います。
とりあえず無情報事前分布と呼ばれるもの使っておこうとか、経験上このくらいの期待値と分散っぽいからエイっと決めちゃおうとか、みんなわりとそんなノリかなという気がしているのですが、職人芸は不安になりますよね。
では依って立つ理論はないのかというと、WAICという情報量規準を考案された渡辺先生の定理が光を照らしてくれるので、そのまま引用します。

定理 : 「(自然な条件を満たす)任意の統計モデルと任意の事前分布について,ベイズ推測の汎化誤差と自由エネルギーはあるシンプルな数学的法則に従っている」
...
与えられた「統計モデル+事前分布」の良さ悪さは定量的に計算できる
...
(1) 定理のユーザーは,自分が用いている学習モデルと事前分布について,その汎化誤差(予測精度)を知ることができます.
(2) 定理のユーザーは,自分が用いている学習モデルと事前分布について,その自由エネルギーを知ることができます.
(3) 定理のユーザーは,汎化誤差あるいは自由エネルギーが小さくなるように,学習モデルと事前分布を決めることができます.

式展開

さてベイズ推定で最も簡単なのが、尤度関数がベルヌーイ分布で共役事前分布をベータ分布にした場合ですね。ここでベータ分布とは、

  • ベータ関数
B (a, b) = \int^1_0 w^{a - 1} (1 - w)^{b - 1} dw
  • ベータ分布
f (w; a, b) = \frac{w^{a - 1} (1 - w)^{b - 1}}{B(a, b)}

というものでした。
さて、初期分布を $f (w; a_0, b_0)$ として、正例 $a_1$ 負例 $b_1$ が観測されたとしましょう。このとき自由エネルギーを式展開してみると、

  • 自由エネルギー
\begin{align}
F_n (\beta)
&= - \frac{1}{\beta} \log \int^1_0 \frac{w^{a_0 - 1} (1 - w)^{b_0 - 1}}{B(a_0, b_0)} w^{a_1 \beta} (1 - w)^{b_1 \beta} dw \\
&= - \frac{1}{\beta} \log \int^1_0 \frac{w^{a_0 + a_1 \beta - 1} (1 - w)^{b_0 + b_1 \beta - 1}}{B(a_0 + a_1 \beta, b_0 + b_1 \beta)} \frac{B(a_0 + a_1 \beta, b_0 + b_1 \beta)}{B(a_0, b_0)} dw \\
&= - \frac{1}{\beta} \log \frac{B(a_0 + a_1 \beta, b_0 + b_1 \beta)}{B(a_0, b_0)} \\
&= \frac{1}{\beta} ( \log B(a_0, b_0) - log B(a_0 + a_1 \beta, b_0 + b_1 \beta) )
\end{align}

というわけで、ベータ分布は単純な形をしているため綺麗に式展開できちゃうみたいなんですよね。(計算ミスしてたら教えてください)
つづけて汎化誤差を求めてみたいところなのですが、汎化誤差は真の分布を知らないと計算不能なので、汎化誤差の推定に使える損失を式展開していきます。

  • 学習損失
\begin{align}
T_n
&= - \frac{1}{a_1 + b_1} (a_1 \log \int^1_0 w \frac{w^{a_0 + a_1 \beta - 1} (1 - w)^{b_0 + b_1 \beta - 1}}{B(a_0 + a_1 \beta, b_0 + b_1 \beta)} dw + b_1 \log \int^1_0 (1 - w) \frac{w^{a_0 + a_1 \beta - 1} (1 - w)^{b_0 + b_1 \beta - 1}}{B(a_0 + a_1 \beta, b_0 + b_1 \beta)} dw) \\
&= - \frac{1}{a_1 + b_1} (a_1 \log \frac{B(a_0 + a_1 \beta + 1, b_0 + b_1 \beta)}{B(a_0 + a_1 \beta, b_0 + b_1 \beta)} + b_1 \log \frac{B(a_0 + a_1 \beta, b_0 + b_1 \beta + 1)}{B(a_0 + a_1 \beta, b_0 + b_1 \beta)}) \\
&= \frac{a_1}{a_1 + b_1} (\log B(a_0 + a_1 \beta, b_0 + b_1 \beta) - log B(a_0 + a_1 \beta + 1, b_0 + b_1 \beta)) \\
&+ \frac{b_1}{a_1 + b_1} (\log B(a_0 + a_1 \beta, b_0 + b_1 \beta) - log B(a_0 + a_1 \beta, b_0 + b_1 \beta + 1))
\end{align}
  • Leave-one-out Cross Validation Loss
\begin{align}
LOOCV_n
&= - \frac{1}{a_1 + b_1} (a_1 \log \int^1_0 w \frac{w^{a_0 + (a_1 - 1) \beta - 1} (1 - w)^{b_0 + b_1 \beta - 1}}{B(a_0 + (a_1 - 1) \beta, b_0 + b_1 \beta)} dw + b_1 \log \int^1_0 (1 - w) \frac{w^{a_0 + a_1 \beta - 1} (1 - w)^{b_0 + (b_1 - 1) \beta - 1}}{B(a_0 + a_1 \beta, b_0 + (b_1 - 1) \beta)} dw) \\
&= - \frac{1}{a_1 + b_1} (a_1 \log \frac{B(a_0 + (a_1 - 1) \beta + 1, b_0 + b_1 \beta)}{B(a_0 + (a_1 - 1) \beta, b_0 + b_1 \beta)} + b_1 \log \frac{B(a_0 + a_1 \beta, b_0 + (b_1 - 1) \beta + 1)}{B(a_0 + a_1 \beta, b_0 + (b_1 - 1) \beta)}) \\
&= \frac{a_1}{a_1 + b_1} (\log B(a_0 + (a_1 - 1) \beta, b_0 + b_1 \beta) - log B(a_0 + (a_1 - 1) \beta + 1, b_0 + b_1 \beta)) \\
&+ \frac{b_1}{a_1 + b_1} (\log B(a_0 + a_1 \beta, b_0 + (b_1 - 1) \beta) - log B(a_0 + a_1 \beta, b_0 + (b_1 - 1) \beta + 1))
\end{align}
  • Importance Sampling Cross Validation Loss
\begin{align}
ISCV_n
&= \frac{1}{a_1 + b_1} (a_1 \log \int^1_0 \frac{1}{w} \frac{w^{a_0 + a_1 \beta - 1} (1 - w)^{b_0 + b_1 \beta - 1}}{B(a_0 + a_1 \beta, b_0 + b_1 \beta)} dw + b_1 \log \int^1_0 \frac{1}{(1 - w)} \frac{w^{a_0 + a_1 \beta - 1} (1 - w)^{b_0 + b_1 \beta - 1}}{B(a_0 + a_1 \beta, b_0 + b_1 \beta)} dw) \\
&= \frac{1}{a_1 + b_1} (a_1 \log \frac{B(a_0 + a_1 \beta - 1, b_0 + b_1 \beta)}{B(a_0 + a_1 \beta, b_0 + b_1 \beta)} + b_1 \log \frac{B(a_0 + a_1 \beta, b_0 + b_1 \beta - 1)}{B(a_0 + a_1 \beta, b_0 + b_1 \beta)}) \\
&= \frac{a_1}{a_1 + b_1} (\log B(a_0 + a_1 \beta - 1, b_0 + b_1 \beta) - log B(a_0 + a_1 \beta, b_0 + b_1 \beta)) \\
&+ \frac{b_1}{a_1 + b_1} (\log B(a_0 + a_1 \beta, b_0 + b_1 \beta - 1) - log B(a_0 + a_1 \beta, b_0 + b_1 \beta))
\end{align}

……渡辺先生の名前を出しておいて恐縮なのですが、WAICは解析的に解くのは難しそうでした。
ただ一般的には計算量が大変なことになるLOOCVがベータ分布の場合はスッと求まるので、とりあえずLOOCVで汎化誤差を推定するのが良さそうです。

ライブラリ

ということで、あとはベータ関数の対数さえ計算できればという感じなのですが、これは各数値計算ライブラリに実装があって、
Pythonなら scipy.special.betaln
Java系なら org.apache.commons.math3.special.Beta.logBeta
あたりを使えばいいと思います。

逆温度が1の場合

と、ここまで書いてきてアレなのですが、逆温度は1に固定することが多く、

LOOCV_n = - \frac{a_1}{a_1 + b_1} \log \frac{B(a_0 + a_1, b_0 + b_1)}{B(a_0 + a_1 - 1, b_0 + b_1)} - \frac{b_1}{a_1 + b_1} \log \frac{B(a_0 + a_1, b_0 + b_1)}{B(a_0 + a_1, b_0 + b_1 - 1)}

ここでベータ関数の性質として、

(a+b) B (a + 1, b) = (a + b) \int^1_0 w^a (1 - w)^{b - 1} dw = - [w^a (1 - w)^b]^1_0 + a \int^1_0 w^{a - 1} (1 - w)^{b - 1} dw = a B(a, b) \\
(a+b) B (a, b + 1) = (a + b) \int^1_0 w^{a - 1} (1 - w)^b dw = [w^a (1 - w)^b]^1_0 + b \int^1_0 w^{a - 1} (1 - w)^{b - 1} dw = b B(a, b)

よってベータ関数を消せて、

LOOCV_n = - \frac{a_1}{a_1 + b_1} \log \frac{a_0 + a_1 - 1}{a_0 + a_1 + b_0 + b_1 - 1} - \frac{b_1}{a_1 + b_1} \log \frac{b_0 + b_1 - 1}{a_0 + a_1 + b_0 + b_1 - 1}

さらに下限を抑えることができて、

LOOCV_n \geqq  \min_{x + y \leqq 1} (- \frac{a_1}{a_1 + b_1} \log x - \frac{b_1}{a_1 + b_1} \log y)

これはクロスエントロピーとみなせるので、

\min_{x + y \leqq 1} (- \frac{a_1}{a_1 + b_1} \log x - \frac{b_1}{a_1 + b_1} \log y) = - \frac{a_1}{a_1 + b_1} \log \frac{a_1}{a_1 + b_1} - \frac{b_1}{a_1 + b_1} \log \frac{b_1}{a_1 + b_1}

よってなるべく大きい $a_0 : b_0 \approx a_1 : b_1$ が LOOCV を小さく、すなわち事前分布を期待値 $\frac{a_1}{a_1 + b_1}$ 分散 0 のパルス(超関数)に近づけていけば LOOCV は小さくなることが分かります。学習損失やISCVでも同様の議論ができますね。

……これは明らかにオーバーフィットしている発散解で何の面白みもないのですが、たとえば同じ事前分布から複数の事後分布を作ってそれぞれの損失をまとめて評価するような場合は、損失最小化でまともな分布に収束したりします。
学習損失・LOOCV・ISCVのまとめ方としては、学習データ数の重み付き平均を取れば良いと思います。

おさらい

以下、渡部先生の『ベイズ統計の理論と方法』と『X.ベイズ統計の理論・モデリング・評価について』講座テキストを参考に、今回の式展開に関わることをメモしておきます。

定義

  • 真の分布 $q ( x )$
  • 学習データ $X^n$
  • 確率モデル $p ( x | w )$
  • 事前分布 $\varphi (w)$
  • 逆温度 $\beta$
  • 分配関数
Z_n (\beta) = \int \varphi (w) \prod^{n}_{i=1} p( X_i | w)^\beta dw
  • 周辺尤度
p (X^n) = Z_n(1) =\int \varphi (w) \prod^{n}_{i=1} p( X_i | w) dw
  • 事後分布
p (w | X^n) = \frac{1}{Z_n (\beta)} \varphi (w) \prod^{n}_{i=1} p( X_i | w)^\beta
  • 予測分布
\begin{align}
p^* (x)
&= \mathbb{E}_w [ p( X | w ) ] \\
&= \int p (x | w) p( w | X^n ) dw
\end{align}
  • KL情報量
D_{KL} ( q || p) = \int q(x) \log \frac{q (x)}{p (x)} dx 
  • 自由エネルギー
F_n(\beta) = - \frac{1}{\beta} \log Z_n(\beta)
  • 汎化損失
\begin{align}
G_n
&= - \mathbb{E}_X [\log \mathbb{E}_w [ p( X | w ) ] ] \\
&= - \int q (x) \log p^* (x) dx
\end{align}

ここで平均を取った $X$ は、学習データ $X^n$ とは独立とする。

  • 学習損失 / 経験損失
\begin{align}
T_n
&= - \frac{1}{n} \sum^n_{i=1} \log \mathbb{E}_w [ p ( X_i | w ) ] \\
&= - \frac{1}{n} \sum^n_{i=1} \log p^* (X_i)
\end{align}
  • Leave-one-out Cross Validation Loss
\begin{align}
LOOCV_n
&= - \frac{1}{n} \sum^n_{i=1} \log ( \int p (X_i | w) p( w | X^n - X_i ) dw ) \\
&= - \frac{1}{n} \sum^n_{i=1} \log ( \int p (X_i | w) \frac{\varphi (w) \prod_{x \in X^n - X_i} p( x | w)^\beta}{\int \varphi (w) \prod_{x \in X^n - X_i} p( x | w)^\beta dw} dw )
\end{align}
  • Importance Sampling Cross Validation Loss
\begin{align}
ISCV_n
&= \frac{1}{n} \sum^n_{i=1} \log \mathbb{E}_w [ \frac{1}{p ( X_i | w )} ] \\
&= \frac{1}{n} \sum^n_{i=1} \log ( \int \frac{1}{p (X_i | w)} p( w | X^n) dw ) \\
&= \frac{1}{n} \sum^n_{i=1} \log ( \int \frac{1}{p (X_i | w)} \frac{\varphi (w) \prod_{x \in X^n} p( x | w)^\beta}{\int \varphi (w) \prod_{x \in X^n} p( x | w)^\beta dw} dw )
\end{align}
  • WAIC (Widely Applicable Information Criterion)
\begin{align}
W_n
&= T_n + \frac{1}{n} \sum^n_{i=1} \mathbb{V}_w [\log p ( X_i | w ) ] \\
&= \frac{1}{n} \sum^n_{i=1} - \log \mathbb{E}_w [ p ( X_i | w ) ] + \mathbb{E}_w [ ( \log p ( X_i | w ) )^2 ] - \mathbb{E}_w[ \log p ( X_i | w ) ]^2 \\
\end{align}

関係性

  • 自由エネルギーとKL情報量

$\beta = 1$ ならば、「学習データについての自由エネルギーの平均」は「真の分布のエントロピーの学習データ数倍」と「「真の分布から発生した学習データの確率分布」の「事前分布で平均した学習データの確率分布」に対するKL情報量」の和。

\begin{align}
\mathbb{E}_{X^n} [F_n (1)]
&= \mathbb{E}_{X^n} \left[- \log p(X^n) \right] \\
&= \mathbb{E}_{X^n} \left[- \log q(X^n) + \log \frac{q(X^n)}{p(X^n)} \right] \\
&= - \int q(x^n ) \sum^n_{i=1} \log q(x_i)  dx^n + \int q(x^n ) \log \frac{q(x^n)}{p(x^n)} dx^n \\
&= - n \int q(x) \log q(x) dx + D_{KL}\left(q(x^n) || p(x^n)\right)
\end{align}
  • 自由エネルギーと汎化損失

$\beta = 1$ ならば、自由エネルギーは汎化損失の和。

\begin{align}
\mathbb{E}_{X^n} [G_n]
&= \mathbb{E}_{X^n} \left[- \int q (x) \log p^* (x) dx \right] \\
&= \mathbb{E}_{X^n} \left[- \int\log \left( \int p (x | w) \frac{1}{Z_n (\beta)} \varphi (w) \prod^{n}_{i=1} p( X_i | w) dw  \right)  q (x) dx \right] \\
&= \mathbb{E}_{X^{n+1}} \left[- \log \frac{Z_{n+1} (1)}{Z_n (1)} \int \frac{1}{Z_{n+1} (\beta)} \varphi (w) \prod^{n+1}_{i=1} p( X_i | w) dw  \right] \\
&= \mathbb{E}_{X^{n+1}} \left[- \log Z_{n+1} (1) + \log Z_n (1) \right] \\
&= \mathbb{E}_{X^{n+1}} \left[F_{n+1} (1) \right] - \mathbb{E}_{X^n} \left[F_n (1) \right] \\
\\
\mathbb{E}_{X^n} \left[F_n (1) \right]
&= \sum^{n-1}_{i=1} \mathbb{E}_{X^i} [G_i]
\end{align}
  • 汎化損失とKL情報量

汎化損失は、「真の分布のエントロピー」と「真の分布の予測分布に対するKL情報量」の和。

\begin{align}
G_n
&= - \int q (x) \log p^* (x) dx \\
&= - \int q(x) \log q(x)  dx + \int q(x) \log \frac{q(x)}{p^*(x)} dx \\
&= - \int q(x) \log q(x)  dx + D_{KL}\left(q(x) || p^*(x)\right)
\end{align}
  • 汎化損失と学習損失

数値実験例:ベイズ法の汎化損失と学習損失』が分かりやすい。

  • 汎化損失と Leave-one-out Cross Validation Loss
\begin{align}
\mathbb{E}_{X^n} [ LOOCV_n ]
&= \int \cdots  \int \left( - \frac{1}{n} \sum^n_{i=1} \log \int p ( x_i | w) p( w | x^n - x_i ) dw \right) q(x_1) dx_1 \cdots q(x_n) dx_n \\
&= - \frac{1}{n} \sum^n_{i=1} \int \cdots  \int \left( \log \int p ( x_i | w) p( w | x^n - x_i ) dw \right) 
 q(x_1) dx_1 \cdots q(x_n) dx_n \\
&= - \frac{1}{n} \sum^n_{i=1} \int \left( \log \int p ( x_i | w) p( w | x^n - x_i ) dw \right) 
 q(x_i) dx_i \\
&= - \frac{1}{n} \sum^n_{i=1} \mathbb{E}_{X^{n-1}} [ G_{n - 1} ] \\
&= \mathbb{E}_{X^{n-1}} [ G_{n - 1} ]
\end{align}
  • Leave-one-out Cross Validation Loss と Importance Sampling Cross Validation Loss

$\beta = 1$ ならば等しい。

\begin{align}
LOOCV_n
&= - \frac{1}{n} \sum^n_{i=1} log( \frac{\int \varphi (w) \prod_{x \in X^n} p( x | w) dw}{\int \varphi (w) \prod_{x \in X^n - X_i} p( x | w) dw} )  \\

ISCV_n
&= \frac{1}{n} \sum^n_{i=1} log( \frac{\int \varphi (w) \prod_{x \in X^n - X_i} p( x | w) dw}{\int \varphi (w) \prod_{x \in X^n} p( x | w) dw} )
\end{align}
12
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?