はじめに
こんにちは @ta-ka です.
この記事は 数学とコンピュータⅡ Advent Calendar 2017 の25日目の記事です.
pythonで制約ボルツマンマシンを実装しました.
教科書として『深層学習』を使いました.
本記事の構成
- はじめに
- 制約ボルツマンマシン
- ボルツマン分布
- 条件付き分布
- 対数尤度関数
- パラメータ更新
- CD法
- 実装
- 結果
- おわりに
制約ボルツマンマシン
下図のような構造を持つ無向グラフを考えます.
$v_{i}$ を可視変数,$h_{j}$ を隠れ変数と呼び,$v_{i}, h_{j}$ は $0$ か $1$ の値をとります.
ボルツマン分布
ボルツマン分布を下式で定義します.
指数部分の値が大きいほど,$\boldsymbol v, \boldsymbol h$ の生起確率が高くなるような分布です.
\begin{align}
p(\boldsymbol v, \boldsymbol h \mid \boldsymbol \theta)
&= \cfrac{1}{Z(\boldsymbol \theta)} \exp \bigl(- \Phi(\boldsymbol v, \boldsymbol h, \boldsymbol \theta) \bigr) \tag{1}
\end{align}
エネルギー関数 $\Phi(\boldsymbol v, \boldsymbol h, \boldsymbol \theta)$,正規化項 $Z(\boldsymbol \theta)$ は以下のように定義してします.
\begin{align}
\Phi(\boldsymbol v, \boldsymbol h, \boldsymbol \theta)
&= - \sum_{i} a_{i} v_{i} - \sum_{j} b_{j} h_{j} - \sum_{i} \sum_{j} w_{ij} v_{i} h_{j} \tag{2} \\
Z(\boldsymbol \theta)
&= \sum_{\boldsymbol v} \sum_{\boldsymbol h} \exp \bigl(- \Phi(\boldsymbol v, \boldsymbol h, \boldsymbol \theta) \bigr) \tag{3}
\end{align}
条件付き分布
可視変数 $\boldsymbol v$ が与えられた時の隠れ変数 $\boldsymbol h$ の条件付き分布 $p(\boldsymbol h \mid \boldsymbol v, \boldsymbol \theta)$ を求めます.
\begin{align}
p(\boldsymbol h \mid \boldsymbol v, \boldsymbol \theta)
&= \cfrac{p(\boldsymbol v, \boldsymbol h \mid \boldsymbol \theta)}{p(\boldsymbol v \mid \boldsymbol \theta)} \\
&= \cfrac{p(\boldsymbol v, \boldsymbol h \mid \boldsymbol \theta)}{\sum_{\boldsymbol h} p(\boldsymbol v, \boldsymbol h \mid \boldsymbol \theta)} \tag{4}
\end{align}
分子を以下のように計算します.
\begin{align}
p(\boldsymbol v, \boldsymbol h \mid \boldsymbol \theta)
&= \cfrac{1}{Z(\boldsymbol \theta)} \exp \bigl(- \Phi(\boldsymbol v, \boldsymbol h, \boldsymbol \theta) \bigr) \\
&= \cfrac{1}{Z(\boldsymbol \theta)} \exp \left( \sum_{i} a_{i} v_{i} + \sum_{j} b_{j} h_{j} + \sum_{i} \sum_{j} w_{ij} v_{i} h_{j} \right) \\
&= \cfrac{1}{Z(\boldsymbol \theta)} \exp \left( \sum_{i} a_{i} v_{i} \right) \exp \left( \sum_{j} h_{j} \Bigl( b_{j} + \sum_{i} w_{ij} v_{i} \Bigr) \right) \\
&= \cfrac{1}{Z(\boldsymbol \theta)} \exp \left( \sum_{i} a_{i} v_{i} \right) \prod_{j} \exp \left(h_{j} \Bigl( b_{j} + \sum_{i} w_{ij} v_{i} \Bigr) \right) \tag{5}
\end{align}
分母を以下のように計算します.
\begin{align}
\sum_{\boldsymbol h} p(\boldsymbol v, \boldsymbol h \mid \boldsymbol \theta)
&= \sum_{\boldsymbol h} \cfrac{1}{Z(\boldsymbol \theta)} \exp \bigl(- \Phi(\boldsymbol v, \boldsymbol h, \boldsymbol \theta) \bigr) \\
&= \cfrac{1}{Z(\boldsymbol \theta)} \exp \left( \sum_{i} a_{i} v_{i} \right) \sum_{\boldsymbol h} \prod_{j} \exp \left(h_{j} \Bigl( b_{j} + \sum_{i} w_{ij} v_{i} \Bigr) \right) \\
&= \cfrac{1}{Z(\boldsymbol \theta)} \exp \left( \sum_{i} a_{i} v_{i} \right) \prod_{j} \left(1 + \exp \Bigl( b_{j} + \sum_{i} w_{ij} v_{i} \Bigr) \right) \tag{6}
\end{align}
式$(4)$, $(5)$, $(6)$より以下の式が得られます.
\begin{align}
p(\boldsymbol h \mid \boldsymbol v, \boldsymbol \theta)
&= \prod_{j} \cfrac{\exp \Bigl(h_{j} (b_{j} + \sum_{i} w_{ij} v_{i}) \Bigr)}{1 + \exp \Bigl( b_{j} + \sum_{i} w_{ij} v_{i} \Bigr)} \tag{7} \\
p(h_{j} = 1 \mid \boldsymbol v, \boldsymbol \theta)
&= \cfrac{\exp \Bigl(b_{j} + \sum_{i} w_{ij} v_{i} \Bigr)}{1 + \exp \Bigl( b_{j} + \sum_{i} w_{ij} v_{i} \Bigr)} \\
&= {\rm sigmoid} \left(b_{j} + \sum_{i} w_{ij} v_{i} \right) \tag{8}
\end{align}
同様に, 隠れ変数 $\boldsymbol h$ が与えられた時の可視変数 $\boldsymbol v$ の条件付き分布 $p(\boldsymbol v \mid \boldsymbol h, \boldsymbol \theta)$ が得られます.
\begin{align}
p(\boldsymbol v \mid \boldsymbol h, \boldsymbol \theta)
&= \prod_{i} \cfrac{\exp \Bigl(v_{i} (a_{i} + \sum_{j} w_{ij} h_{j}) \Bigr)}{1 + \exp \Bigl( a_{i} + \sum_{j} w_{ij} h_{j} \Bigr)} \tag{9} \\
p(v_{i} = 1 \mid \boldsymbol h, \boldsymbol \theta)
&= \cfrac{\exp \Bigl(a_{i} + \sum_{j} w_{ij} h_{j} \Bigr)}{1 + \exp \Bigl( a_{i} + \sum_{j} w_{ij} h_{j} \Bigr)} \\
&= {\rm sigmoid} \left(a_{i} + \sum_{j} w_{ij} h_{j} \right) \tag{10}
\end{align}
式$(8)$, $(10)$ はパラメータ更新で利用します.
対数尤度関数
対数尤度関数は以下のように計算されます.
\begin{align}
\ln L &= \ln \prod_{n} p(\boldsymbol v_{n} \mid \boldsymbol \theta) \\
&= \sum_{n} \ln p(\boldsymbol v_{n} \mid \boldsymbol \theta) \\
&= \sum_{n} \left[ \sum_{i} a_{i} v_{ni} + \sum_{j} \ln \left(1 + \exp \Bigl( b_{j} + \sum_{i} w_{ij} v_{ni} \Bigr) \right) - \ln Z(\boldsymbol \theta) \right] \tag{11}
\end{align}
パラメータ更新
パラメータの更新式を求めるために
式$(11)$の対数尤度関数を $w_{ij}$, $a_{i}$, $b_{j}$ で偏微分します.
\begin{align}
\cfrac{\partial \ln L}{\partial w_{ij}}
&= \sum_{n} \cfrac{v_{ni} \exp \Bigl( b_{j} + \sum_{i} w_{ij} v_{ni} \Bigr)}{1 + \exp \Bigl( b_{j} + \sum_{i} w_{ij} v_{ni} \Bigr)} - \cfrac{N}{Z(\boldsymbol \theta)} \sum_{\boldsymbol v} \sum_{\boldsymbol h} v_{i} h_{j} \exp \left( \sum_{i} a_{i} v_{i} + \sum_{j} b_{j} h_{j} + \sum_{i} \sum_{j} w_{ij} v_{i} h_{j} \right) \\
&= \sum_{n} v_{ni} {\rm sigmoid} \left(b_{j} + \sum_{i} w_{ij} v_{ni} \right) - N \sum_{\boldsymbol v} \sum_{\boldsymbol h} v_{i} h_{j} p(\boldsymbol v, \boldsymbol h \mid \boldsymbol \theta) \\
&= \sum_{n} v_{ni} p(h_{nj} = 1 \mid \boldsymbol v_{n}, \boldsymbol \theta) - N \sum_{\boldsymbol v} \sum_{\boldsymbol h} v_{i} h_{j} p(\boldsymbol v, \boldsymbol h \mid \boldsymbol \theta) \tag{12} \\ \\
\cfrac{\partial \ln L}{\partial a_{i}}
&= \sum_{n} v_{ni} - \cfrac{N}{Z(\boldsymbol \theta)} \sum_{\boldsymbol v} \sum_{\boldsymbol h} v_{i} \exp \left( \sum_{i} a_{i} v_{i} + \sum_{j} b_{j} h_{j} + \sum_{i} \sum_{j} w_{ij} v_{i} h_{j} \right) \\
&= \sum_{n} v_{ni} - N \sum_{\boldsymbol v} \sum_{\boldsymbol h} v_{i} p(\boldsymbol v, \boldsymbol h \mid \boldsymbol \theta) \tag{13} \\ \\
\cfrac{\partial \ln L}{\partial b_{j}}
&= \sum_{n} \cfrac{v_{ni} \exp \Bigl( b_{j} + \sum_{i} w_{ij} v_{ni} \Bigr)}{1 + \exp \Bigl( b_{j} + \sum_{i} w_{ij} v_{ni} \Bigr)} - \cfrac{N}{Z(\boldsymbol \theta)} \sum_{\boldsymbol v} \sum_{\boldsymbol h} h_{j} \exp \left( \sum_{i} a_{i} v_{i} + \sum_{j} b_{j} h_{j} + \sum_{i} \sum_{j} w_{ij} v_{i} h_{j} \right) \\
&= \sum_{n} v_{ni} {\rm sigmoid} \left(b_{j} + \sum_{i} w_{ij} v_{ni} \right) - N \sum_{\boldsymbol v} \sum_{\boldsymbol h} h_{j} p(\boldsymbol v, \boldsymbol h \mid \boldsymbol \theta) \\
&= \sum_{n} v_{ni} p(h_{nj} = 1 \mid \boldsymbol v_{n}, \boldsymbol \theta) - N \sum_{\boldsymbol v} \sum_{\boldsymbol h} h_{j} p(\boldsymbol v, \boldsymbol h \mid \boldsymbol \theta) \tag{14}
\end{align}
式$(12)$, $(13)$, $(14)$の第 $2$ 項は $v_{i}$, $h_{j}$ のすべての状態における重み付き和であり,計算コストが大きくなります.
例えば可視変数,隠れ変数が合わせて $100$ 個ある場合,状態の通り数は $2^{100}$ 通りになります.
CD法
上述のように式$(12)$, $(13)$, $(14)$の第 $2$ 項は計算コストが大きいため,
ギブスサンプリングを使って近似的に計算することを考えます.
CD法では,初期値 $\boldsymbol v^{(0)} = \boldsymbol v_{n}$ とし,サンプリングを $T$ 回繰り返します.
サンプリングした結果を使ってパラメータを更新します.
\begin{align}
\Delta w_{ij} &= \epsilon \left(v_{i}^{(0)} p_{j}^{(0)} - v_{i}^{(T)} p_{j}^{(T)} \right) \tag{15} \\
\Delta a_{i} &= \epsilon \left(v_{i}^{(0)} - v_{i}^{(T)} \right) \tag{16} \\
\Delta b_{j} &= \epsilon \left(p_{j}^{(0)} - p_{j}^{(T)} \right) \tag{17}
\end{align}
ただし,$p_{j}^{(T)} = p(h_{j} = 1 \mid \boldsymbol v^{(T)}, \boldsymbol \theta)$ です.
(CD法のサンプリングについては理解することができなかったので,後日勉強して追記します.)
実装
RBMによる二値画像の自己符号化器を実装しました.
python3で実装したプログラムを掲載します.
import numpy as np
class RBM:
def __init__(self, n_v, n_h):
self.w = np.random.randn(n_h, n_v)
self.b = np.random.randn(n_h, 1)
self.a = np.random.randn(n_v, 1)
def train(self, V, epsilon, epoch, T):
for epo in range(epoch):
for (n, v_0) in enumerate(V.T):
v_0 = np.copy(v_0).reshape(-1, 1)
p_h_0 = np.copy(self.sigmoid(self.w.dot(v_0) + self.b))
v, p_h = self.encode_decode(np.copy(v_0), T)
self.update(v_0, v, p_h_0, p_h, epsilon)
def encode_decode(self, v, T):
for t in range(T):
# visible
p_h = self.sigmoid(self.w.dot(v) + self.b)
h = (np.random.rand(n_h, 1) < p_h).astype('float64')
# hidden
p_v = self.sigmoid(self.w.T.dot(h) + self.a)
v = (np.random.rand(n_v, 1) < p_v).astype('float64')
return (v, p_h)
def update(self, v_0, v, p_h_0, p_h, epsilon):
self.w += epsilon * (v_0.T * p_h_0 - v.T * p_h)
self.a += epsilon * (v_0 - v)
self.b += epsilon * (p_h_0 - p_h)
def sigmoid(self, x):
return 1.0 / (1.0 + np.exp(-x))
if __name__ == "__main__":
# init parameters
N = 20
side = 8
n_v = side ** 2
n_h = 32
V = np.round(np.random.rand(n_v, N))
T = 5
# train
rbm = RBM(n_v, n_h)
rbm.train(V, 0.1, 200, T)
結果
$20$ サンプルで学習し,復元した画像のうち $10$ サンプルを下図に示します.
左がランダムに生成した $8 \times 8$ の元画像です.
右がRBMにより復元した画像です.隠れ変数の数を $32$ に設定し,$64$ 次元の元画像を復元しています.
おわりに
RBMによる二値画像の自己符号化器を実装できました.
CD法のサンプリングを勉強して理解できたら追記します.