LoginSignup
1
4

More than 1 year has passed since last update.

【PyTorch】RBMを自動微分で実装する

Last updated at Posted at 2023-04-05

Restricted Boltzmann MachinesをPyTorchの自動微分で実装する

目次
意気込み
RBMの概要
 - 自由エネルギーによる勾配計算
実装
 - モデルの定義
 - 学習
 - 結果
おわり
おまけ(色々な導出)

意気込み

Restricted Boltzmann Machines (RBM)はみんな大好きな手法で、調べればコードは腐るほどヒットしますね。
ただ、どれも手計算で得られる勾配を足し引きしてパラメータの更新をしてます。

そうするとモデルを弄ったりする度、自分で勾配を計算し直さないといけません。
マジでかったるいです。


なので今回は自動微分を使ってRBMを実装します。

RBMの概要

\begin{align}
E(\boldsymbol{v}, \boldsymbol{h}) &= \frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2
-\boldsymbol{c}^T\boldsymbol{h}
-\biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{Wh} \\

Z(\theta) &= \sum_{\boldsymbol{v}, \boldsymbol{h}}\rm{e}^{-E(\boldsymbol{v},\boldsymbol{h})} \\

p(\boldsymbol{v}, \boldsymbol{h}) &= \frac{1}{Z(\theta)}e^{-E(\boldsymbol{v},\boldsymbol{h})} \\

p(h_j=1|\boldsymbol{v}) &= \rm{sigmoid}\left(c_j + \sum_i W_{ij}\frac{v_i}{\sigma^2_i}\right) \\
p(v_i|\boldsymbol{h}) &= \mathcal{N}\left(v_i;\ b_i + \sum_jW_{ij}h_j,\ \sigma^2_i\right)
\end{align}

というルールで、上の3つは順にエネルギー関数、分配関数、同時確率です。
RBMの独立性により、各層の条件付き確率は下2つのようになります。

通常、RBMと言えば可視層$\boldsymbol{v}$ と 隠れ層$\boldsymbol{h}$ はそれぞれ0, 1(ベルヌーイ分布に従う)ですが、本記事ではGaussian-Bernoulli型で、可視層は実数値です。

自由エネルギー

今日お世話になる自由エネルギー $F(\boldsymbol{v})$ と勾配は下記

\begin{align}
\frac{\partial \mathcal{L}}{\partial \theta} 
&= -\mathbb{E}_{p_{data}}\Bigl[\frac{\partial F(\boldsymbol{v})}{\partial \theta}\Bigr]
+
\mathbb{E}_{p_{model}}\Bigl[\frac{\partial F(\boldsymbol{v})}{\partial \theta}\Bigr] \\

F(\boldsymbol{v}) 
&=\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2 
 - \sum_j \log\left(1 + \exp(\lambda_j)\right) \\

\lambda_j 
&= c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\end{align}

詳細は「おまけ(色々な導出)

つまり、入力$\boldsymbol{v}_0$,生成$\boldsymbol{v}_1$に対するエネルギーの誤差を微分したものが勾配になる.
つまりつまり、このエネルギーの誤差がPyTorchで言うLossになる。

実装

モデルの定義

class RBM(nn.Module):
  def __init__(self, I=39, J=8, lr=0.001):
    super().__init__()
    self.I, self.J = I, J
    self.lr=lr

    scale = 1./(I + J)
    rnd = np.random.randn

    # setting parameters
    self.W = nn.Parameter(torch.Tensor(rnd(I, J)*scale))
    self.b = nn.Parameter(torch.Tensor(np.zeros(I)))
    self.c = nn.Parameter(torch.Tensor(np.zeros(J)))
    self.z = nn.Parameter(torch.Tensor(np.zeros(I)))

    self.optimizer = optim.Adam(self.parameters(), lr=lr)

zは分散パラメータの補助パラメータです。
分散は非負制約があるので、$z = \log \sigma^2$として$z$を代わりに学習します。

 ##############################################################################
  def energy(self, v):
    sigma2 = torch.exp(self.z)

    unit1 = 0.5 * torch.sum((v  - self.b)**2/sigma2, axis=1)

    tmp = torch.matmul(v/sigma2, self.W) + self.c
    tmp = 1 + torch.exp(tmp)
    unit2 = torch.sum(torch.log(tmp), axis=1)

    return unit1 - unit2

  ##############################################################################
  def encode(self, v):                
    sigma2 = torch.exp(self.z)
    ph = torch.sigmoid(torch.matmul(v/sigma2, self.W) + self.c)
    #ph = torch.softmax(torch.matmul(v/sigma2, self.W) + self.c, dim=1)
    return ph

  ##############################################################################
  def decode(self, h):
    pv = torch.matmul(h, self.W.T) + self.b
    return pv

  ##############################################################################
  def feedforward(self, v0):
    sigma = torch.sqrt(torch.exp(self.z))

    # contrastive-divergence-1
    ph0 = self.encode(v0)
    h0 = torch.bernoulli(ph0)
    pv1 = self.decode(h0)
    v1 = torch.randn(v0.shape[0], v0.shape[1])*sigma + pv1
    ph1 = self.encode(v1)
    h1 = torch.bernoulli(ph1)
    return h0, v1, h1

feedforwardのCD法では
$p(h|v)$からきちんと$h$をサンプリングしてから$p(v|h)$を計算します。
よく$p(h|v)$のまま進める人がいますが、これだと分散の学習が上手くいかないです。

学習

   h0, batch_v1, h1 = self.feedfoward(batch_v0)
   batch_v1 = torch.tensor(batch_v1.detach().numpy())

   # compute loss
   self.optimizer.zero_grad()                         
   e0 = self.energy(batch_v0)
   e1 = self.energy(batch_v1)
        
   loss = - torch.mean(-e0 + e1)  # 最小化問題に帰着するためマイナス倍

   # optimize
   loss.backward()
   self.optimizer.step()

学習部分では、
・ Lossである -F(v0) + F(v1)を求める
・そのためにv1を求める
ことをします。

ただし、F(v1)のv1は今まで辿ってきた経路をリセットしておきす。
というのも、さっき見た勾配計算を進めてみれば分かります。v1はv0同様観測値として勾配計算を手でやってみれば
よく見る勾配式に一致します。

結果

スクリーンショット 2021-08-30 10.34.27.png

39次のメルケプストラムで、学習エポックは30、ミニバッチサイズは64、最適化はAdam(βはデフォルト)、学習率は0.001で。
学習データは各次元、平均0分散1に正規化してます。

RBMならメルケプの再現度はこんなもんでしょうという感じです。

おわり

$p(h|v)$の計算にソフトマックス関数を使うと途中でNaNを吐くようになってお手上げでした。
自動微分を使わず手計算勾配を使ったときは問題なかったし、
今回だって自動微分勾配と手計算勾配とでハイパーつよつよ相関(相関係数ほぼ1)だったので問題ないはずだけども。

エネルギー関数の設計的にソフトマックス関数がアウトなのかも。

おまけ(色々な導出)

自由エネルギーの導出

\begin{align}
p(\boldsymbol{v}) 
&= \sum_{\boldsymbol{h}}p(\boldsymbol{v}, \boldsymbol{h}) \\
&= \frac{1}{Z(\theta)}\sum_{\boldsymbol{h}}e^{-E(\boldsymbol{v}, \boldsymbol{h})} \\
&= \frac{1}{Z(\theta)}\sum_{\boldsymbol{h}}\exp\left( 
-\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2
+\boldsymbol{c}^T\boldsymbol{h}
+\biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{Wh}
\right) \\
&= \frac{1}{Z(\theta)}\sum_{\boldsymbol{h}}\exp\left(
-\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2
\right)
\exp\left(
\boldsymbol{c}^T\boldsymbol{h}
+\biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{Wh}
\right) \\
 
&= \frac{1}{Z(\theta)}\exp\left(
-\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2 
\right)
\sum_{\boldsymbol{h}}\exp\left(
\boldsymbol{c}^T\boldsymbol{h}
+\biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{Wh}
\right)
\end{align}

一番右側の $\sum$ の箇所をもう少し整理すると

\begin{align}
\sum_{\boldsymbol{h}}\exp\left(
\boldsymbol{c}^T\boldsymbol{h}
+\biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{Wh}
\right) 

&= \sum_{\boldsymbol{h}}\exp\left(
\sum_j \left(
c_jh_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}h_j
\right)\right) \\

&= \sum_{\boldsymbol{h}}\exp\left(
\sum_j h_j\left(
c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\right)\right) \\

&= \sum_{\boldsymbol{h}}\prod_j\exp\left(
h_j\left(
c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\right)\right) \hspace{15mm}(*)\\
&= \prod_j \left(1 + \exp\left(
c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\right)\right) \\
 
&= \prod_j \left(1 + \exp(\lambda_j)\right) \hspace{15mm} ←括弧内を\lambdaと置いた\\
&= \prod_j \exp\left(\log\left(1 + \exp(\lambda_j)\right)\right) \\
&= \exp\left(\sum_j \log\left(1 + \exp(\lambda_j)\right)\right)
\end{align}

従って、尤度 $p(\boldsymbol{v})$ は

\begin{align}
p(\boldsymbol{v})
&= \frac{1}{Z(\theta)}\exp\left(
-\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2 
\right)\exp\left(\sum_j \log\left(1 + \exp(\lambda_j)\right)\right) \\

&= \frac{1}{Z(\theta)}\exp\left(
-\frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2 
 + \sum_j \log\left(1 + \exp(\lambda_j)\right)\right) \\

&= \frac{1}{Z(\theta)}\exp\bigl(-F(\boldsymbol{v})\bigr) \hspace{15mm}←括弧内を-Fと置いた
\end{align}

つまり、$e^{-F(\boldsymbol{v})} = \sum_h e^{-E(\boldsymbol{v},\boldsymbol{h})}$ ということ。

この$F(\boldsymbol{v})$が自由エネルギー。

F(\boldsymbol{v}) = \frac{1}{2}\biggl|\biggl|\frac{\boldsymbol{v} - \boldsymbol{b}}{\boldsymbol{\sigma}}\biggr|\biggr|^2 
 - \sum_j \log\left(1 + \exp(\lambda_j)\right)

勾配の導出

$p(\boldsymbol{v}) = \frac{1}{Z(\theta)}\exp\bigl(-F(\boldsymbol{v})\bigr)$ より、対数尤度の偏微分は、

\begin{align}
\frac{\partial}{\partial \theta}\log p(\boldsymbol{v})
&= \frac{\partial}{\partial \theta} \bigl( -F(\boldsymbol{v}) - \log Z(\theta) \bigr) \\

&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\frac{\partial}{\partial \theta}Z(\theta) \\

&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\frac{\partial}{\partial \theta}\sum_{\boldsymbol{v}, \boldsymbol{h}}\rm{e}^{-E(\boldsymbol{v},\boldsymbol{h})} \\

&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\frac{\partial}{\partial \theta}\sum_{\boldsymbol{v}}\sum_{\boldsymbol{h}}\rm{e}^{-E(\boldsymbol{v},\boldsymbol{h})} \\

&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\frac{\partial}{\partial \theta}\sum_{\boldsymbol{v}}e^{-F(\boldsymbol{v})} \\

&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\sum_{\boldsymbol{v}}\frac{\partial}{\partial \theta}e^{-F(\boldsymbol{v})} \\
 
&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
- \frac{1}{Z(\theta)}\sum_{\boldsymbol{v}}-e^{-F(\boldsymbol{v})}\frac{\partial F(\boldsymbol{v})}{\partial \theta} \\

&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
+ \sum_{\boldsymbol{v}}\frac{e^{-F(\boldsymbol{v})}}{Z(\theta)}\frac{\partial F(\boldsymbol{v})}{\partial \theta} \\

&= -\frac{\partial F(\boldsymbol{v})}{\partial \theta}
+ \sum_{\boldsymbol{v}}p(\boldsymbol{v})\frac{\partial F(\boldsymbol{v})}{\partial \theta} \\
\end{align}

目的化関数は $ \mathcal{L} = \sum_n \log p(\boldsymbol{v}^{(n)}) $ だから、

\begin{align}
\frac{\partial \mathcal{L}}{\partial \theta} 
&= \sum_n \Bigl[ -\frac{\partial F(\boldsymbol{v}^{(n)})}{\partial \theta}
+ \sum_{\boldsymbol{v}}p(\boldsymbol{v})\frac{\partial F(\boldsymbol{v})}{\partial \theta} \Bigr] \\

&= -\mathbb{E}_{p_{data}}\Bigl[\frac{\partial F(\boldsymbol{v})}{\partial \theta}\Bigr]
+
\mathbb{E}_{p_{model}}\Bigl[\frac{\partial F(\boldsymbol{v})}{\partial \theta}\Bigr]
\end{align}

おまけ (*)について

\begin{align}
\sum_{\boldsymbol{h}}\prod_j\exp\left(
h_j\left(
c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\right)\right)
&= \prod_j \left(1 + \exp\left(
c_j + \biggl(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\biggr)^T\boldsymbol{W_{:j}}
\right)\right)
\end{align}

左辺から右辺を導きます。

RBMでは $h \in $ {$0, 1$} なので、例えばJ=3のとき
$h = [0,0,0]\ [1,0,0]\ [0,1,0]\ [1,1,0]\ [0,0,1]\ [1,0,1]\ [0,1,1]\ [1,1,1]$ になる。
そうすると
($\lambda_j = c_j + (\boldsymbol{v}/\sigma^2)^T\boldsymbol{W}_{:j})$ とおく)

\begin{align}
\sum_{\boldsymbol{h}}\prod_j\exp(h_j\lambda_j)
&= e^0e^0e^0 + e^{\lambda_1}e^0e^0 + e^0e^{\lambda_2}e^0 + e^{\lambda_1}e^{\lambda_2}e^0 + \\
& \hspace{20mm}e^0e^0e^{\lambda_3} + e^{\lambda_1}e^0e^{\lambda_3} + e^0e^{\lambda_2}e^{\lambda_3} + e^{\lambda_1}e^{\lambda_2}e^{\lambda_3} \\

&= 1 + e^{\lambda_1} + e^{\lambda_2} + e^{\lambda_3} + e^{\lambda_1}e^{\lambda_2}
+ e^{\lambda_1}e^{\lambda_3} + e^{\lambda_2}e^{\lambda_3} +  e^{\lambda_1}e^{\lambda_2}e^{\lambda_3}  \\

&= (1 + e^{\lambda_1})(1 + e^{\lambda_2})(1 + e^{\lambda_3}) \\
&= \prod_j 1 + e^{\lambda_j}
\end{align}

以上。

1
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
4