ロジスティック回帰のラプラス近似によるベイズ推定
ロジスティック回帰
特徴量(説明変数)ベクトル$\boldsymbol{x}_n$に対するバイナリ値を取る目的変数値$y_n \in \{0,1\}$の組が$N$個得られているとします。
ロジスティック回帰は$(\boldsymbol{x}_n,y_n)$を
\begin{align}
y_n &\sim {\rm Bern}(y_n; \sigma(\boldsymbol{x}_n^T \boldsymbol{w})) \\
\sigma(x) &= \frac{1}{1+\exp(-x)}
\end{align}
とモデル化します。$\boldsymbol{w}$は係数ベクトルパラメータ、$\sigma(x)$はシグモイド関数で、${\rm Bern}(y;\pi)$は確率$\pi$で$y=1$、確率$1-\pi$で$y=0$となるベルヌーイ分布:
\begin{align}
{\rm Bern}(y;\pi) &= \pi^y (1-\pi)^{1-y}
\end{align}
です。
事後分布のラプラス近似
学習データを$\mathcal{D}_N \equiv {(\boldsymbol{x}_n,y_n)}_{n=1}^N$とすると対数尤度は
\begin{align}
\ln p(\mathcal{D}_N|\boldsymbol{w})
&= \ln \prod_{n=1}^N {\rm Bern}(y_n; \sigma(\boldsymbol{x}_n^T \boldsymbol{w})) \\
&= \sum_{n=1}^N \left\{
y_n \ln \sigma(\boldsymbol{x}_n^T \boldsymbol{w}) + (1-y_n) \ln (1 - \sigma(\boldsymbol{x}_n^T \boldsymbol{w}))
\right\}
\end{align}
です。対数を取ったときに$\boldsymbol{w}$についてこのような関数形になる(名前の付いた)分布はありませんね。。よって予測分布が解析的に求まるとか、既存のサンプラーを使うことのできるような共役事前分布はありません。
そこでここでは事後分布を正規分布で近似(ラプラス近似)することにします。事前分布も簡単のため正規分布とします。
\begin{align}
p(\boldsymbol{w}) &= \mathcal{N}(\boldsymbol{w};\boldsymbol{\mu}_0,\boldsymbol{\Sigma}_0)
\end{align}
事後分布の対数は
\begin{align}
\ln p(\boldsymbol{w}|\mathcal{D}_N)
&= \ln p(\mathcal{D}_N|\boldsymbol{w}) + \ln p(\boldsymbol{w}) + {\rm const.} \\
&= \ln \prod_{n=1}^N {\rm Bern}(y_n; \sigma(\boldsymbol{x}_n^T \boldsymbol{w})) + \ln \mathcal{N}(\boldsymbol{w};\boldsymbol{\mu}_0,\boldsymbol{\Sigma}_0) +{\rm const.} \\
&= \sum_{n=1}^N \left\{
y_n \ln \sigma(\boldsymbol{x}_n^T \boldsymbol{w}) + (1-y_n) \ln (1 - \sigma(\boldsymbol{x}_n^T \boldsymbol{w}))
\right\} \\
&\hspace{4mm}
- \frac{1}{2} (\boldsymbol{w} - \boldsymbol{\mu}_0)^T \boldsymbol{\Sigma}_0^{-1} (\boldsymbol{w} - \boldsymbol{\mu}_0) + {\rm const.}
\end{align}
です。分布の確率密度関数の値が一番大きくなる値をモード(最頻値)と言います。事後分布のモードは事後確率を最大化する値であり、MAP(Maximum A Posteriori)推定値と言います。事後分布のモードを$\hat{\boldsymbol{w}}$として、事後分布の対数を$\hat{\boldsymbol{w}}$の周りで二次近似すると、モードでは一階微分は$\boldsymbol{0}$なので
\begin{align}
\ln p(\boldsymbol{w}|\mathcal{D}_N)
&\approx \ln p(\hat{\boldsymbol{w}}|\mathcal{D}_N) + \frac{1}{2} (\boldsymbol{w}-\hat{\boldsymbol{w}})^T \left(\frac{\partial^2}{\partial \boldsymbol{w} \boldsymbol{w}^T} \ln p(\boldsymbol{w}|\mathcal{D}_N) \bigg|_{\boldsymbol{w}=\hat{\boldsymbol{w}}} \right) (\boldsymbol{w}-\hat{\boldsymbol{w}}) \\
&= \ln p(\hat{\boldsymbol{w}}|\mathcal{D}_N) + \frac{1}{2} (\boldsymbol{w}-\hat{\boldsymbol{w}})^T \boldsymbol{H}(\hat{\boldsymbol{w}}) (\boldsymbol{w}-\hat{\boldsymbol{w}}) \\
\boldsymbol{H}(\hat{\boldsymbol{w}})
&\equiv \frac{\partial^2}{\partial \boldsymbol{w} \boldsymbol{w}^T} \ln p(\boldsymbol{w}|\mathcal{D}_N) \bigg|_{\boldsymbol{w}=\hat{\boldsymbol{w}}}
\end{align}
となります。これは$\boldsymbol{w}$についての二次形式なので、事後分布を正規分布で近似していることになります。
\begin{align}
p(\boldsymbol{w}|\mathcal{D}_N)
&\approx \mathcal{N}(\boldsymbol{w};\boldsymbol{\mu}_N, \boldsymbol{\Sigma}_N) \\
\boldsymbol{\mu}_N
&\equiv \hat{\boldsymbol{w}} \\
\boldsymbol{\Sigma}_N
&\equiv -\boldsymbol{H}(\hat{\boldsymbol{w}})^{-1}
\end{align}
このように事後分布を、モード周りで正規分布で近似することをラプラス近似と呼びます。$\boldsymbol{H}(\boldsymbol{w})$のような二階導関数行列をヘッセ行列と呼びます。$\boldsymbol{H}(\hat{\boldsymbol{w}})$は「事後分布の対数」のヘッセ行列の、モード$\hat{\boldsymbol{w}}$での値です。
ニュートン法による学習
次にモード$\hat{\boldsymbol{w}}$の求め方です。勾配法では
\begin{align}
\hat{\boldsymbol{w}}^{(t+1)} &\leftarrow \hat{\boldsymbol{w}}^{(t)} + \alpha \boldsymbol{g}(\hat{\boldsymbol{w}}^{(t)}) \\
\boldsymbol{g}(\boldsymbol{w}) &\equiv \frac{\partial}{\partial \boldsymbol{w}} \ln p(\boldsymbol{w}|\mathcal{D}_N)
\end{align}
と更新します(ここでは最大化のため$\alpha \boldsymbol{g}(\boldsymbol{w})$の符号は$+$になっています)。このとき$\alpha$は学習率で、勾配法ではパラメータとなります。
勾配法は勾配が$\boldsymbol{0}$になるところで収束します。関数値が$0$となるところを求める高速なアルゴリズムにニュートン法があります。簡単のため一変数関数で考えると、関数$f(x)$の$x^{(t)}$での一次近似は
\begin{align}
f(x) = f(x^{(t)}) + f^\prime(x^{(t)}) (x - x^{(t)})
\end{align}
であり、$f(x)=0$を解くと
\begin{align}
x = x^{(t)} - \frac{f(x^{(t)})}{f^\prime(x^{(t)})}
\end{align}
です。この更新を反復して、$f(x)=0$の解を求めるのがニュートン法です。
今回の場合は勾配$\boldsymbol{g}(\boldsymbol{w})$が$\boldsymbol{0}$となるところを求めるので、勾配の微分、すなわち目的関数(事後分布の対数)の二階微分$\boldsymbol{H}(\boldsymbol{w})$が出てきて、ニュートン法の更新式は
\begin{align}
\hat{\boldsymbol{w}}^{(t+1)} &\leftarrow \hat{\boldsymbol{w}}^{(t)} - \boldsymbol{H}(\hat{\boldsymbol{w}}^{(t)})^{-1} \boldsymbol{g}(\hat{\boldsymbol{w}}^{(t)})
\end{align}
となります。ニュートン法が収束したとき、ニュートン法の更新で計算しているヘッセ行列が、そのまま事後分布の共分散行列を求めるのに利用できます。
ニュートン法とラプラス近似は相性の良い組み合わせですね。
次回は「ベイズ推定6:周辺尤度」です。