LoginSignup
0
0

LogisticTS方策におけるラプラス近似した事後分布(ガウス分布)の導出

Last updated at Posted at 2024-02-27

この記事でやること

LogisticTS方策におけるパラメータのサンプリングに必要なラプラス近似した事後分布(ガウス分布)の期待値ベクトルと共分散行列を導出する.

Notation

記号 説明
$t \in {1,..,T}$ 時刻$t$
$i \in \mathcal{I}$ アクション集合
$X_i(t)$ 時刻$t$に観測されたアクション$i$からの報酬
$a_{i,t} \in \mathbb{R}^d$ アクション$i$と時刻$t$に依存した$d$次元の文脈ベクトル
$a_{i,t,k} \in \mathbb{R}$ アクション$i$と時刻$t$に依存した文脈ベクトルの$k$番目の要素
$\theta_i \in \mathbb{R}^d$ $d$次元のアクション$i$の真のパラメータベクトル
$\theta_{i,k} \in \mathbb{R}$ アクション$i$の真のパラメータベクトルの$k$番目の要素
$a_{i(s),s} \in \mathbb{R}^d$ 時刻$s$時点で選ばれたアクション$i$と時刻$s$に依存した$d$次元の文脈ベクトル
$a_{i(s),s, k} \in \mathbb{R}$ 時刻$s$時点で選ばれたアクション$i$と時刻$s$に依存した文脈ベクトルの$k$番目の要素
$I_d$ $d$次元の単位行列

文脈付きバンディットにおける2値報酬モデル

\begin{align}
\sigma(x) &= \frac{1}{1 + \exp(-x)} \\[10pt]
p_{i,t} &= \sigma(\theta_{i}^{\top}a_{i,t}) \\[10pt]
X_i(t) &\overset{\mathrm{i.i.d}}{\sim} \mathcal{Be}(p_{i,t})
\end{align}

以上のように,報酬が2値の場合はアクション特有のパラメータ$\theta_i$と時刻$t$で観測された文脈(特徴量)の内積をシグモイド関数に通す. つまり, 確率変数である報酬$X_i(t)$はパラメータ$p_{i,t}$のベルヌーイ分布に従う.よって, 期待値と分散は以下の通り.

\begin{align}
\mathbb{E}_{X_i(t)} \left[ X_i(t) \right] &= p_{i,t} \\[5pt]
\mathbb{V}_{X_i(t)} \left[ X_i(t) \right] &= p_{i,t}(1-p_{i,t})
\end{align}

LogisticTS方策とは

観測される報酬が離散値かつ文脈付きバンディット版のTS方策. LinTS方策と同じく以下のように, パラメータ$\theta_i$に事前分布を導入することで, 時刻$t$までに収集された報酬のもとでのパラメータ$\theta_i$の事後分布をもとにサンプリングを行い, 線形モデル$\theta_{i}^{\top}a_{i,t}$の内積値最大のアクションを選択する方策である.

\begin{align}
p(\theta_i) &= \mathcal{N}(\theta_i | \mathbf{0},\sigma_{0}^2I_d) \\
&= \frac{1}{(2\pi)^{\frac{d}{2}} |\sigma_{0}^2I_d|^{\frac{1}{2}}} \exp \left(- \frac{\theta_{i}^{\top} \theta_i}{2\sigma_{0}^2} \right)
\end{align}

LinTS方策との違い

観測される報酬に仮定する確率分布が異なる.つまり, LinTS方策では報酬$X_i(t)$はガウス分布に従う.LogisticTS方策では,報酬$X_i(t)$はベルヌーイ分布(カテゴリカル分布)に従う.そして, パラメータ$\theta_i$が事前分布としてガウス分布に従うならば, LogisticTS方策における事後分布は陽に定まらない。なので, 分布を近似したりして事後分布を求めることが実装上のポイントとなる.

ラプラス近似

LogisticTS方策における事後分布の近似の手段の一つがラプラス近似である.ラプラス近似では, 事後確率分布のMAP解の近傍で2次のテイラー展開をすることでガウス分布を近似する. そして,近似した分布からパラメータをサンプリングする.

\begin{align}
p(\theta_i | \{ X_i(s) \}_{s=1}^t) &\approx \mathcal{N}(\theta_i | \mu_t,\Sigma_{t}) \\
&\approx  \frac{1}{(2\pi)^{\frac{d}{2}} |\Sigma_t|^{\frac{1}{2}}} \exp \left(- \frac{1}{2} \left(\theta_i - \mu_t \right)^{\top} \Sigma_{t}^{-1} \left(\theta_i - \mu_t \right) \right) \\
-\log p(\theta_i | \{ X_i(s) \}_{s=1}^t) &\approx \frac{1}{2} \left(\theta_i - \mu_t \right)^{\top} \Sigma_{t}^{-1} \left(\theta_i - \mu_t \right) + \frac{d}{2}\log(2\pi)+\frac{1}{2}\log(|\Sigma_t|)
\end{align}

対数関数は単調増加関数なので,両辺に適用しても最適解は変わらない.そして,

\mathcal{J}_t(\theta_i) = -\log p(\theta_i | \{ X_i(s) \}_{s=1}^t)

と置き換え, MAP解$\hat{\theta}_{i}^{\text{MAP}}$近傍での$\mathcal{J}_t(\theta_i)$の2次テイラー展開は以下の通り.

\begin{align}
\mathcal{J}_t(\theta_i) &\approx \frac{1}{2}(\theta_i - \hat{\theta_i}^{\text{MAP}})^{\top} \nabla^2 \mathcal{J}_t(\hat{\theta_i}^{\text{MAP}})(\theta_i - \hat{\theta_i}^{\text{MAP}}) + \nabla \mathcal{J}_t(\hat{\theta_i}^{\text{MAP}})^{\top}(\theta_i - \hat{\theta_i}^{\text{MAP}}) + \mathcal{J}_t(\hat{\theta_i}^{\text{MAP}}) \\
&\approx \frac{1}{2}(\theta_i - \hat{\theta_i}^{\text{MAP}})^{\top} \nabla^2 \mathcal{J}_t(\hat{\theta_i}^{\text{MAP}})(\theta_i - \hat{\theta_i}^{\text{MAP}}) + \text{Const}
\end{align}

この近似式の第一項目は見慣れた多変量ガウス分布の指数部分になっている.

\mu_t = \hat{\theta}_{i}^{\text{MAP}}, \quad \Sigma_t = \nabla^2 \mathcal{J}_t(\hat{\theta_i}^{\text{MAP}})^{-1} \\
\tilde{\theta}_i \sim \mathcal{N}(\mu_t, \Sigma_t) 

よって, ガウス分布の期待値ベクトルをMAP解$\hat{\theta}_{i}^{\text{MAP}}$, 共分散行列を$\mathcal{J}_t(\hat{\theta_i}^{\text{MAP}})$のヘッセ行列$\nabla^2 \mathcal{J}_t(\hat{\theta_i}^{\text{MAP}})$の逆行列として, サンプリングするということがLogisticTS方策における事後分布のラプラス近似である.
ということで, サンプリングに必要な事後分布のMAP解$\hat{\theta_i}^{\text{MAP}}$と負の対数尤度関数のヘッセ行列$\nabla^2 \mathcal{J}_t(\hat{\theta_i}^{\text{MAP}})$を導出することが今回のテーマである.

事後分布の導出

負の対数尤度関数を最小化する基準でMAP解を求めることを目指す.

\begin{align}
p(\theta_i | \{ X_i(s) \}_{s=1}^t) &\propto p(\theta_i)P(\{ X_i(s) \}_{s=1}^t | \theta_i) \\
&= p(\theta_i) \prod_{s=1}^t p(X_i(t)|\theta_i) \\
&= \frac{1}{(2\pi)^{\frac{d}{2}} |\sigma_{0}^2I_d|^{\frac{1}{2}}} \exp \left(- \frac{\theta_{i}^{\top} \theta_i}{2\sigma_{0}^2} \right) \prod_{s=1}^t p_{i(s),s}^{X_i(s)} \left( 1 - p_{i(s),s}\right)^{(1-X_i(s))} \\[15px]
- \log p(\theta_i | \{ X_i(s) \}_{s=1}^t) &\propto - \sum_{s=1}^{t} X_i(s) \log(p_{i(s),s}) + (1 - X_i(s))\log(1-p_{i(s),s}) + \frac{1}{2\sigma_{0}^2}\| \theta_i \|_{2}^{2} + \frac{d}{2}\log(2\pi)+\frac{1}{2}\log(|\sigma_{0}^2I_d|) \\
&\propto - \sum_{s=1}^{t} X_i(s) \log(p_{i(s),s}) + (1 - X_i(s))\log(1-p_{i(s),s}) + \frac{1}{2\sigma_{0}^2}\| \theta_i \|_{2}^{2} \\
\end{align}

式変形を繰り返すと, 事後分布に負の対数を取った関数は見慣れた正則化付きクロスエントロピー損失と比例する.なので,

\mathcal{J}_t(\theta_i) = - \sum_{s=1}^{t} X_i(s) \log(p_{i(s),s}) + (1 - X_i(s))\log(1-p_{i(s),s}) + \frac{1}{2\sigma_{0}^2}\| \theta_i \|_{2}^{2}
\hat{\theta_i}^{\text{MAP}} = \underset{\theta_i}{\text{argmin}} \quad \mathcal{J}_t(\theta_i) 

として,損失を最小にすればよい. ただし,クロスエントロピー損失では,最適解が一点に定まるとは限らないため数値的に解く必要がある.今回は,後にヘッセ行列をサンプリングに使いたいのでニュートン法を用いる.

\hat{\theta_i}^{(l+1)} = \hat{\theta_i}^{(l)} - \nabla^2 \mathcal{J}_t(\hat{\theta_i}^{(l)})^{-1}\nabla \mathcal{J}_t(\hat{\theta_i}^{(l)})

勾配ベクトルの導出

正則化付きクロスエントロピー損失の勾配を求めれば良い.

\nabla \mathcal{J}_t(\theta_{i}) = \sum_{s=1}^{t}(p_{i(s),s}-X_i(s))a_{i(s),s} + \frac{1}{\sigma_{0}^2}\theta_i

ヘッセ行列の導出

ヘッセ行列は対称行列なので, 対角要素とそれ以外を導出すれば必要十分.

\begin{align}
\frac{\partial \mathcal{J}_t(\theta_i)}{\partial^2 \theta_{i,k}^2} &= \sum_{s=1}^t p_{i(s),s} (1-p_{i(s),s}) a_{i(s),s,k}^2 + \frac{1}{\sigma_0^2} \\
\frac{\partial \mathcal{J}_t(\theta_{i})}{\partial \theta_{i,k} \partial \theta_{i,j}} &= \sum_{s=1}^t p_{i(s),s} (1-p_{i(s),s}) a_{i(s),s,k} a_{i(s),s,j}
\end{align}

よって、$\mathcal{J}_t(\theta_i)$のヘッセ行列は以下の通り.

\nabla^2 \mathcal{J}_t(\theta_i) = \sum_{s=1}^t p_{i(s),s} (1-p_{i(s),s}) a_{i(s),s}a_{i(s),s}^{\top}  + \frac{1}{\sigma_0^2}I_d

これらの勾配ベクトルとヘッセ行列を用いて, 何かしらの停止条件を満たすまで反復を続け,MAP解を得る.

近似した事後分布の再掲

\begin{align}
\mu_t &= \hat{\theta_i}^{\text{MAP}} \\
\Sigma_t &= \nabla^2 \mathcal{J}_t(\hat{\theta_i}^{\text{MAP}})^{-1} \\
&= \left( \sum_{s=1}^t p_{i(s),s} (1-p_{i(s),s}) a_{i(s),s}a_{i(s),s}^{\top}  + \frac{1}{\sigma_0^2}I_d \right)^{-1}\\[20px]
\end{align}
p(\theta_i | \{ X_i(s) \}_{s=1}^t) \approx \mathcal{N}(\theta_i | \mu_t,\Sigma_{t})

参考文献

・本多淳也,中村篤祥:バンディット問題の理論とアルゴリズム,講談社(2016)
・元田浩ほか訳:パターン認識と機械学習上,ベイズ理論による統計的予測,シュプリンガー・ジャパン (2007)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0