Restricted Boltzmann MachinesをPyTorchの自動微分で実装する
目次
・意気込み
・RBMの概要
- 自由エネルギーによる勾配計算
・実装
- モデルの定義
- 学習
- 結果
・おわり
・おまけ(色々な導出)
意気込み
Restricted Boltzmann Machines (RBM)はみんな大好きな手法で、調べればコードは腐るほどヒットしますね。
ただ、どれも手計算で得られる勾配を足し引きしてパラメータの更新をしてます。
そうするとモデルを弄ったりする度、自分で勾配を計算し直さないといけません。
マジでかったるいです。
なので今回は自動微分を使ってRBMを実装します。
RBMの概要
\begin{align}
E(\boldsymbol{v}, \boldsymbol{h}) &= \frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2
-\boldsymbol{c}^T\boldsymbol{h}
-\biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{Wh} \\
Z(\theta) &= \sum_{\boldsymbol{v}, \boldsymbol{h}}\rm{e}^{-E(\boldsymbol{v},\boldsymbol{h})} \\
p(\boldsymbol{v}, \boldsymbol{h}) &= \frac{1}{Z(\theta)}e^{-E(\boldsymbol{v},\boldsymbol{h})} \\
p(h_j=1|\boldsymbol{v}) &= \rm{sigmoid}\left(c_j + \sum_i W_{ij}\frac{v_i}{\sigma^2_i}\right) \\
p(v_i|\boldsymbol{h}) &= \mathcal{N}\left(v_i;\ b_i + \sum_jW_{ij}h_j,\ \sigma^2_i\right)
\end{align}
というルールで、上の3つは順にエネルギー関数、分配関数、同時確率です。
RBMの独立性により、各層の条件付き確率は下2つのようになります。
通常、RBMと言えば可視層$\boldsymbol{v}$ と 隠れ層$\boldsymbol{h}$ はそれぞれ0, 1(ベルヌーイ分布に従う)ですが、本記事ではGaussian-Bernoulli型で、可視層は実数値です。
自由エネルギー
今日お世話になる自由エネルギー $F(\boldsymbol{v})$ と勾配は下記
\begin{align}
\frac{\partial \mathcal{L}}{\partial \theta}
&= -\mathbb{E}_{p_{data}}\Bigl[\frac{\partial F(\boldsymbol{v})}{\partial \theta}\Bigr]
+
\mathbb{E}_{p_{model}}\Bigl[\frac{\partial F(\boldsymbol{v})}{\partial \theta}\Bigr] \\
F(\boldsymbol{v})
&=\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2
- \sum_j \log\left(1 + \exp(\lambda_j)\right) \\
\lambda_j
&= c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\end{align}
詳細は「おまけ(色々な導出)」
つまり、入力$\boldsymbol{v}_0$,生成$\boldsymbol{v}_1$に対するエネルギーの誤差を微分したものが勾配になる.
つまりつまり、このエネルギーの誤差がPyTorchで言うLossになる。
実装
モデルの定義
class RBM(nn.Module):
def __init__(self, I=39, J=8, lr=0.001):
super().__init__()
self.I, self.J = I, J
self.lr=lr
scale = 1./(I + J)
rnd = np.random.randn
# setting parameters
self.W = nn.Parameter(torch.Tensor(rnd(I, J)*scale))
self.b = nn.Parameter(torch.Tensor(np.zeros(I)))
self.c = nn.Parameter(torch.Tensor(np.zeros(J)))
self.z = nn.Parameter(torch.Tensor(np.zeros(I)))
self.optimizer = optim.Adam(self.parameters(), lr=lr)
zは分散パラメータの補助パラメータです。
分散は非負制約があるので、$z = \log \sigma^2$として$z$を代わりに学習します。
##############################################################################
def energy(self, v):
sigma2 = torch.exp(self.z)
unit1 = 0.5 * torch.sum((v - self.b)**2/sigma2, axis=1)
tmp = torch.matmul(v/sigma2, self.W) + self.c
tmp = 1 + torch.exp(tmp)
unit2 = torch.sum(torch.log(tmp), axis=1)
return unit1 - unit2
##############################################################################
def encode(self, v):
sigma2 = torch.exp(self.z)
ph = torch.sigmoid(torch.matmul(v/sigma2, self.W) + self.c)
#ph = torch.softmax(torch.matmul(v/sigma2, self.W) + self.c, dim=1)
return ph
##############################################################################
def decode(self, h):
pv = torch.matmul(h, self.W.T) + self.b
return pv
##############################################################################
def feedforward(self, v0):
sigma = torch.sqrt(torch.exp(self.z))
# contrastive-divergence-1
ph0 = self.encode(v0)
h0 = torch.bernoulli(ph0)
pv1 = self.decode(h0)
v1 = torch.randn(v0.shape[0], v0.shape[1])*sigma + pv1
ph1 = self.encode(v1)
h1 = torch.bernoulli(ph1)
return h0, v1, h1
feedforwardのCD法では
$p(h|v)$からきちんと$h$をサンプリングしてから$p(v|h)$を計算します。
よく$p(h|v)$のまま進める人がいますが、これだと分散の学習が上手くいかないです。
学習
h0, batch_v1, h1 = self.feedfoward(batch_v0)
batch_v1 = torch.tensor(batch_v1.detach().numpy())
# compute loss
self.optimizer.zero_grad()
e0 = self.energy(batch_v0)
e1 = self.energy(batch_v1)
loss = - torch.mean(-e0 + e1) # 最小化問題に帰着するためマイナス倍
# optimize
loss.backward()
self.optimizer.step()
学習部分では、
・ Lossである -F(v0) + F(v1)を求める
・そのためにv1を求める
ことをします。
ただし、F(v1)のv1は今まで辿ってきた経路をリセットしておきす。
というのも、さっき見た勾配計算を進めてみれば分かります。v1はv0同様観測値として勾配計算を手でやってみれば
よく見る勾配式に一致します。
結果
39次のメルケプストラムで、学習エポックは30、ミニバッチサイズは64、最適化はAdam(βはデフォルト)、学習率は0.001で。
学習データは各次元、平均0分散1に正規化してます。
RBMならメルケプの再現度はこんなもんでしょうという感じです。
おわり
$p(h|v)$の計算にソフトマックス関数を使うと途中でNaNを吐くようになってお手上げでした。
自動微分を使わず手計算勾配を使ったときは問題なかったし、
今回だって自動微分勾配と手計算勾配とでハイパーつよつよ相関(相関係数ほぼ1)だったので問題ないはずだけども。
エネルギー関数の設計的にソフトマックス関数がアウトなのかも。
おまけ(色々な導出)
自由エネルギーの導出
\begin{align}
p(\boldsymbol{v})
&= \sum_{\boldsymbol{h}}p(\boldsymbol{v}, \boldsymbol{h}) \\
&= \frac{1}{Z(\theta)}\sum_{\boldsymbol{h}}e^{-E(\boldsymbol{v}, \boldsymbol{h})} \\
&= \frac{1}{Z(\theta)}\sum_{\boldsymbol{h}}\exp\left(
-\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2
+\boldsymbol{c}^T\boldsymbol{h}
+\biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{Wh}
\right) \\
&= \frac{1}{Z(\theta)}\sum_{\boldsymbol{h}}\exp\left(
-\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2
\right)
\exp\left(
\boldsymbol{c}^T\boldsymbol{h}
+\biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{Wh}
\right) \\
&= \frac{1}{Z(\theta)}\exp\left(
-\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2
\right)
\sum_{\boldsymbol{h}}\exp\left(
\boldsymbol{c}^T\boldsymbol{h}
+\biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{Wh}
\right)
\end{align}
一番右側の $\sum$ の箇所をもう少し整理すると
\begin{align}
\sum_{\boldsymbol{h}}\exp\left(
\boldsymbol{c}^T\boldsymbol{h}
+\biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{Wh}
\right)
&= \sum_{\boldsymbol{h}}\exp\left(
\sum_j \left(
c_jh_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}h_j
\right)\right) \\
&= \sum_{\boldsymbol{h}}\exp\left(
\sum_j h_j\left(
c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\right)\right) \\
&= \sum_{\boldsymbol{h}}\prod_j\exp\left(
h_j\left(
c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\right)\right) \hspace{15mm}(*)\\
&= \prod_j \left(1 + \exp\left(
c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\right)\right) \\
&= \prod_j \left(1 + \exp(\lambda_j)\right) \hspace{15mm} ←括弧内を\lambdaと置いた\\
&= \prod_j \exp\left(\log\left(1 + \exp(\lambda_j)\right)\right) \\
&= \exp\left(\sum_j \log\left(1 + \exp(\lambda_j)\right)\right)
\end{align}
従って、尤度 $p(\boldsymbol{v})$ は
\begin{align}
p(\boldsymbol{v})
&= \frac{1}{Z(\theta)}\exp\left(
-\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2
\right)\exp\left(\sum_j \log\left(1 + \exp(\lambda_j)\right)\right) \\
&= \frac{1}{Z(\theta)}\exp\left(
-\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2
+ \sum_j \log\left(1 + \exp(\lambda_j)\right)\right) \\
&= \frac{1}{Z(\theta)}\exp\bigl(-F(\boldsymbol{v})\bigr) \hspace{15mm}←括弧内を-Fと置いた
\end{align}
つまり、$e^{-F(\boldsymbol{v})} = \sum_h e^{-E(\boldsymbol{v},\boldsymbol{h})}$ ということ。
この$F(\boldsymbol{v})$が自由エネルギー。
F(\boldsymbol{v}) = \frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2
- \sum_j \log\left(1 + \exp(\lambda_j)\right)
勾配の導出
$p(\boldsymbol{v}) = \frac{1}{Z(\theta)}\exp\bigl(-F(\boldsymbol{v})\bigr)$ より、対数尤度の偏微分は、
\begin{align}
\frac{\partial}{\partial \theta}\log p(\boldsymbol{v})
&= \frac{\partial}{\partial \theta} \bigl( -F(\boldsymbol{v}) - \log Z(\theta) \bigr) \\
&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\frac{\partial}{\partial \theta}Z(\theta) \\
&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\frac{\partial}{\partial \theta}\sum_{\boldsymbol{v}, \boldsymbol{h}}\rm{e}^{-E(\boldsymbol{v},\boldsymbol{h})} \\
&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\frac{\partial}{\partial \theta}\sum_{\boldsymbol{v}}\sum_{\boldsymbol{h}}\rm{e}^{-E(\boldsymbol{v},\boldsymbol{h})} \\
&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\frac{\partial}{\partial \theta}\sum_{\boldsymbol{v}}e^{-F(\boldsymbol{v})} \\
&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\sum_{\boldsymbol{v}}\frac{\partial}{\partial \theta}e^{-F(\boldsymbol{v})} \\
&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\sum_{\boldsymbol{v}}-e^{-F(\boldsymbol{v})}\frac{\partial F(\boldsymbol{v})}{\partial \theta} \\
&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
+ \sum_{\boldsymbol{v}}\frac{e^{-F(\boldsymbol{v})}}{Z(\theta)}\frac{\partial F(\boldsymbol{v})}{\partial \theta} \\
&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
+ \sum_{\boldsymbol{v}}p(\boldsymbol{v})\frac{\partial F(\boldsymbol{v})}{\partial \theta} \\
\end{align}
目的化関数は $ \mathcal{L} = \sum_n \log p(\boldsymbol{v}^{(n)}) $ だから、
\begin{align}
\frac{\partial \mathcal{L}}{\partial \theta}
&= \sum_n \Bigl[ -\frac{\partial F(\boldsymbol{v}^{(n)})}{\partial \theta}
+ \sum_{\boldsymbol{v}}p(\boldsymbol{v})\frac{\partial F(\boldsymbol{v})}{\partial \theta} \Bigr] \\
&= -\mathbb{E}_{p_{data}}\Bigl[\frac{\partial F(\boldsymbol{v})}{\partial \theta}\Bigr]
+
\mathbb{E}_{p_{model}}\Bigl[\frac{\partial F(\boldsymbol{v})}{\partial \theta}\Bigr]
\end{align}
おまけ (*)について
\begin{align}
\sum_{\boldsymbol{h}}\prod_j\exp\left(
h_j\left(
c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\right)\right)
&= \prod_j \left(1 + \exp\left(
c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\right)\right)
\end{align}
左辺から右辺を導きます。
RBMでは $h \in $ {$0, 1$} なので、例えばJ=3のとき
$h = [0,0,0]\ [1,0,0]\ [0,1,0]\ [1,1,0]\ [0,0,1]\ [1,0,1]\ [0,1,1]\ [1,1,1]$ になる。
そうすると
($\lambda_j = c_j + (\boldsymbol{v}/\sigma^2)^T\boldsymbol{W}_{:j})$ とおく)
\begin{align}
\sum_{\boldsymbol{h}}\prod_j\exp(h_j\lambda_j)
&= e^0e^0e^0 + e^{\lambda_1}e^0e^0 + e^0e^{\lambda_2}e^0 + e^{\lambda_1}e^{\lambda_2}e^0 + \\
& \hspace{20mm}e^0e^0e^{\lambda_3} + e^{\lambda_1}e^0e^{\lambda_3} + e^0e^{\lambda_2}e^{\lambda_3} + e^{\lambda_1}e^{\lambda_2}e^{\lambda_3} \\
&= 1 + e^{\lambda_1} + e^{\lambda_2} + e^{\lambda_3} + e^{\lambda_1}e^{\lambda_2}
+ e^{\lambda_1}e^{\lambda_3} + e^{\lambda_2}e^{\lambda_3} + e^{\lambda_1}e^{\lambda_2}e^{\lambda_3} \\
&= (1 + e^{\lambda_1})(1 + e^{\lambda_2})(1 + e^{\lambda_3}) \\
&= \prod_j 1 + e^{\lambda_j}
\end{align}
以上。