はじめに
雰囲気でベイズ更新について語っていたので、きちんと勉強し直したいと思ってまとめました。n番煎じ記事です。
参考
- 5分でわかる(範囲の)ベイズ統計学(https://www.slideshare.net/matsukenbook/15-59154892)
- 確率分布のベイズ推定(http://arduinopid.web.fc2.com/P19.html)
- ココイチのスプーンの当選確率をベイズ推定する(https://rmizutaa.hatenablog.com/entry/2019/02/15/200700)
手順
ベイズ更新
ベイズ推定をざっくりまとめると、尤度(=モデル)のパラメータをMLEで点推定するのではなく、ベイズの定理を用いてパラメータの分布を推定するというものです。ベイズの定理は次の式で表されます。
p(w|x) = \frac{p(x|w)p(w)}{p(x)}
$p(x|w)$が尤度(=モデル)、$p(w)$が事前分布、$p(x)$がデータの発生確率です。$p(w|x)$を事後分布と言います。
この式は、モデルのパラメータ$w$の事前分布に尤度をかけることで、データ$x$が与えられたときのパラメータ$w$の分布(事後分布)が得られるという式になっています。
ベイズ更新はベイズの定理を使用して、逐次的に事後分布を更新する仕組みです。以下の展開はこちらを参考にしました。
まず、ベイズの定理の式変形を行います。
\begin{align}
&p(A, B, C) = p(A|B, C)p(B, C) = p(B, C|A)p(A) \\
\Rightarrow& p(A|B, C) = \frac{p(B, C|A)p(A)}{p(B, C)}
\end{align}
ここまでは一般に成り立ちます。
次に、BとCは独立かつAを与えたもとで条件付き独立であるとします。すると次の式が得られます。
\begin{align}
p(A|B, C) &= \frac{p(B, C|A)p(A)}{p(B, C)} \\
&= \frac{p(B|A)p(C|A)p(A)}{p(B)p(C)} \\
&= \frac{p(C|A)}{p(C)}\frac{p(B|A)p(A)}{p(B)} \\
&= \frac{p(C|A)}{p(C)}p(A|B)
\end{align}
さて、Aに$w$を、Bに最初に観測されたデータ$x_1$を、Cに二つ目に観測されたデータ$x_2$を入れてみます。
「$x_1$と$x_2$は独立かつ$w$を与えたもとで条件付き独立」とは、観測されたデータたちは独立に分布から生成されており、かつ尤度(=モデル)も$x_1$と$x_2$が独立になるようにモデリングするということを意味します。この特性は時系列に相関があるようなデータでは成り立たないことに注意してください。
このとき、次の式が成り立ちます。
\begin{align}
p(w|x_1, x_2) = \frac{p(x_2|w)}{p(x_2)}p(w|x_1)
\end{align}
データ$x_1$と$x_2$が与えられたもとでの$w$の事後分布は、$p(w|x_1)$を事前分布にもつ、$x_2$と$w$のベイズの定理になっています!
この関係を利用して、適当な事前分布を与える→データが与えられたら事後分布を求める→事後分布を事前分布に読み替える→データが与えられたら事後分布を求める→...と逐次的に事後分布を更新することができます。これがベイズ更新の仕組みです。
ベイズ更新の例
こちらを参考にしてベイズ更新を計算してみます。
歪んだコインを考えます。このコインで表が出る確率をベイズ推定してみましょう。コインの出目はそれぞれの試行で独立に得られるとします。
まず、コインの出目について尤度、すなわちモデルを考えます。コインの表裏は二値変数なのでベルヌーイ分布に従います。したがって尤度はベルヌーイ分布になります。表を$x=1$、裏を$x=0$に対応させます。$w$は0から1の間の実数で、表が出る確率そのものに対応します。したがって、ベイズ更新で求められる$w$の分布そのものが、表が出る確率の分布になります。
ベルヌーイ分布は次のように書けます。
\begin{align}
p(x|w) = w^x(1-w)^{1-x}
\end{align}
事前分布は無情報事前分布として$p(w)=1$としましょう。$p(x)$は未知ですが、$w$に注目した場合は規格化定数なので、$p(w|x)$の積分が1になる条件から自動的に求めることができます。
では、コインを振っていきます。
1回目:表
\begin{align}
p(w|x_1=1) \propto p(x_1=1|w)p(w) = w
\end{align}
0から1まで積分すると1になるので、すでに規格化されています。
\begin{align}
p(w|x_1=1) = w
\end{align}
これが2回目の事前分布$p(w) = p(w|x_1=1) = w$になります。
2回目:表
\begin{align}
p(w|x_1=1, x_2=1) \propto p(x_2=1|w)p(w) = w^2
\end{align}
積分すると、1/3になるので、規格化すると次のようになります。
\begin{align}
p(w|x_1=1, x_2=1) = 3w^2
\end{align}
これが3回目の事前分布$p(w) = p(w|x_1=1, x_2=1) = 3w^2$になります。
3回目:裏
\begin{align}
p(w|x_1=1, x_2=1, x_3=0) \propto p(x_3=0|w)p(w) = 3(1-w)w^2
\end{align}
積分すると、1/4になるので、規格化すると次のようになります。
\begin{align}
p(w|x_1=1, x_2=1, x_3=0) = 12(1-w)w^2
\end{align}
このようにして、コインが表になる確率$w$の分布を求めることができます。
ベイズ更新の例の実装
上で見た例をPythonで実装してみましょう。実装はこちらを参考にしました。
import sympy
import numpy as np
import matplotlib.pyplot as plt
np.random.rand()
# コインの表裏の系列
xs = [1, 1, 0] # 表、表、裏
# 事前分布は無情報事前分布とします
prior_prob = 1
# 積分するシンボル
w = sympy.Symbol('w')
# 初期化しておきます
# jupyter notebookで繰りし実行する場合に必要です
posterior_prob = None
# 事後分布の逐次計算
for x in xs:
# 事後分布(未規格化)を計算
if x==1:
posterior_prob = w*prior_prob
else:
posterior_prob = (1-w)*prior_prob
# 規格化
Z = sympy.integrate(posterior_prob, (w, 0, 1))
posterior_prob = posterior_prob/Z
# 事前分布の置き換え
prior_prob = posterior_prob
plt.figure(figsize=(5, 4))
X = np.linspace(0, 1, 100)
plt.plot(X, [posterior_prob.subs(w, i) for i in X])
plt.xlabel("w")
plt.show()
おまけ
もう少しデータがある場合の推定をやってみます。
表が出る確率が0.35のコインで、試行回数を30回に増やして推定してみます。
データが増えるほど推定が正確になっていくので、分布がシャープになっていきます。
import sympy
import numpy as np
import matplotlib.pyplot as plt
np.random.rand(0)
def bernoulli_sampler(w, n):
"""wは1が出る確率で、nは生成するデータの数"""
xs = np.random.rand(n)
xs = xs<w
return xs.astype("int")
# コインの表裏の系列
xs = bernoulli_sampler(0.35, 30)
# 事前分布は無情報事前分布とします
prior_prob = 1
# 積分するシンボル
w = sympy.Symbol('w')
# 初期化しておきます
# jupyter notebookで繰りし実行する場合に必要です
posterior_prob = None
# 事後分布の逐次計算
for x in xs:
# 事後分布(未規格化)を計算
if x==1:
posterior_prob = w*prior_prob
else:
posterior_prob = (1-w)*prior_prob
# 規格化
Z = sympy.integrate(posterior_prob, (w, 0, 1))
posterior_prob = posterior_prob/Z
# 事前分布の置き換え
prior_prob = posterior_prob
plt.figure(figsize=(5, 4))
X = np.linspace(0, 1, 100)
plt.plot(X, [posterior_prob.subs(w, i) for i in X])
plt.xlabel("w")
plt.show()
おわりに
BとCは独立かつAを与えたもとで条件付き独立であるとする部分がポイントです。