1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

RBMs Without Tears RBMの再来

Posted at

RBMs Without Tears RBMの再来

RBMs Without Tears
GBRMs
Langevin Sampling
Gibbs-Langevin Sampling
全体の学習アルゴリズム
結果
実装
おまけ

RBMs Without Tears

2022年にHinton先生らが発表した論分。
Gaussian-Bernoulli RBMs Without Tears」を触ってみる。

VAEやGANばっかりが活躍してて毎晩涙で枕を濡らしてるRestricted Boltzmann Machinesがついに泣き止んだそう。

without tears で「工夫」みたいな意味になるの英語むずい

今回のミソ
① GibbsサンプリングにLangevinサンプリング(ランジュバンサンプリング)を組み込む
② 勾配のノルムをクリッピングする
そうした結果、オリジナルVAE並の精度を発揮するようになった。

GAN+VAEには遠く及ばないけど

隠れ層が1層しかないRBMでこの快挙はスゴイ!というのが主張。

ということで!このGibbs-Langevinサンプリングを導入したアルゴリズムを実装する。
※論文とはちょこちょこ変数名変えてる

GBRBMs

Gaussian-Bernoulli Restricted Boltzmann Machinesの定式はこちら。これは基本的なGBRBMsと一緒なので特に問題はないかと。

\begin{align}
E(\boldsymbol{v}, \boldsymbol{h}) &= \frac{1}{2}\left(\frac{\boldsymbol{v}-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)^{\text{T}}\left(\frac{\boldsymbol{v}-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)
- \left(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\right)^{\text{T}}\boldsymbol{Wh} - \boldsymbol{b}^{\text{T}}\boldsymbol{h} \\

p(\boldsymbol{v}\ |\ \boldsymbol{h}) &= \mathcal{N}\left(\boldsymbol{v}\ |\ \boldsymbol{Wh} + \boldsymbol{\mu},\ diag(\boldsymbol{\sigma}^2)\right) \\

p(\boldsymbol{h}_j = 1\ |\ \boldsymbol{v}) &= \Big[\text{Sigmoid}\left(\boldsymbol{W}^{\text{T}}\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2} + \boldsymbol{b} \right)\Big]_j
\end{align}

Langevin Sampling

\begin{align}
\boldsymbol{x}^k &= \boldsymbol{x}^{k-1} - \gamma \frac{\partial E(\boldsymbol{x})}{\partial \boldsymbol{v}} + \sqrt{2\gamma}\xi_{k} \\

\xi_k &\sim \mathcal{N}(\xi_k\ |\ \boldsymbol{0}, \boldsymbol{1}) 
\end{align}

で得られるサンプリング法をLangevin Monte Carlo法という。

何にでも使えるけど、今回で言うと第2項の導関数はRBMのエネルギー関数を使う。
$\gamma$ はステップ幅。

Gibbs-Langevin Sampling

通常のギブスサンプリング(Contrastive Divergence法)では、$\boldsymbol{v}^0 = \boldsymbol{v_{data}}$ として、
サンプリングを $K$ 回繰り返した $\boldsymbol{v}^K$ をモデルの出力とする。

\begin{align}
\boldsymbol{v}^k &\sim p(\boldsymbol{v}|\boldsymbol{h}^{k-1}) \\ 

\boldsymbol{h}^{k-1} &\sim p(\boldsymbol{h}|\boldsymbol{v}^{k-1}) 
\end{align}

多くの場合、K=1

今回は、RBMの涙を止めるために $\boldsymbol{v}$ のサンプリングのみ、ランジュバンサンプリングを使う。
このとき、サンプリングの最初 $\boldsymbol{v_0}$ は標準正規分布のノイズにすると良いらしい。
image.png

$\alpha_m$ はコサインスケジューラーで管理。初期値は20。

\alpha_m = \frac{1}{2}\alpha_0\left(1 + \cos\left(\frac{m}{M}\pi\right)\right)

また、$\nabla_\boldsymbol{v}E(\boldsymbol{v}, \boldsymbol{h})$ は

\nabla_\boldsymbol{v}E(\boldsymbol{v}, \boldsymbol{h})
= \frac{\boldsymbol{v} - \boldsymbol{\mu} - \boldsymbol{Wh}}{\boldsymbol{\sigma}^2} 

導出は「おまけ」へ

全体の学習アルゴリズム

image.png

注意!!

論文では
・メトロポリス調整をするかどうか
・サンプリングに $\boldsymbol{h}$ を含めるかどうか
議論してるけど、画像生成の実験で、「$\boldsymbol{h}$ は含めて調整無し」が一番精度良かったので
今回の実装でもそうしてる。

結果

39次元のメルケプストラムで75発話(フレーム数80,000弱)で、隠れ層8次元、学習率0.01、ミニバッチ数128、CD法のループ数20、ランジュバンのループ数10、エポック数30で回した結果。

Figure_1.png

全然じゃん!!!

論文ではもっとドデカいデータ数で、ループ数ももっと大きいから
それくらい規模がでかいと有効なのかも。
スモールなプロジェクトだとあんましって感じ?

実装

def sampler(self): 
  sigma2_mean = torch.mean(self.get_var()).detach()
  v = torch.randn(self.nbatch, self.I)
  ph = self.prob_h(v)
  h = torch.bernoulli(ph)

  tmp = cosine_schedule(a0=20 *sigma2_mean, T = self.Lstep)
  for k in range(self.Kstep):    # CD法のループ
    for l in range(self.Lstep):    # ランジュバンサンプリングのループ
      grad_v = self.energy_grad_v(v, h)
      e = torch.randn(v.shape[0], v.shape[1])
      v = v - tmp[l]*grad_v + np.sqrt(2*tmp[l])*e
      
    ph = self.prob_h(v)
    h = torch.bernoulli(ph)

  return v, ph

↑これはさっき載せたアルゴリズムの箇所。


↓その他

##############################################################################
def get_var(self):
  return torch.exp(self.z).clip(min=1e-8)

##############################################################################
def prob_h(self, v):         
  v_ = v / self.get_var()
  ph = torch.matmul(v_, self.W) + self.b
  return torch.sigmoid(ph)

##############################################################################
def prob_v(self, h):
  return torch.matmul(h, self.W.T) + self.mu

##############################################################################
def energy_grad_v(self, v, h):
  sigma2 = self.get_var()
  grad_v = (v - self.mu) - torch.matmul(h, self.W.T)
  return grad_v / sigma2 / v.shape[0]

##############################################################################
# L2ノルムなのかL1ノルムなのか知らんけど、とりあえずL2で。
def norm_clip(self, th=10., norm_type=2): 
  nn.utils.clip_grad_norm_(\
      parameters=self.parameters(), max_norm=th, norm_type=norm_type)

論分に書いてある各パラメータの勾配の導出は「おまけ」へ


pytorchで実装するからロスをちゃんと定義してあげる。

##############################################################################
def compute_loss(self, v, h, v_, h_):
  v = v.clone().detach()
  h = h.clone().detach()
  v_ = v_.clone().detach()
  h_ = h_.clone().detach()

  E = self.energy(v, h)
  E_ = self.energy(v_, h_)
  return -torch.mean(-E + E_) 

##############################################################################
def energy(self, v, h):
  sigma2 = self.get_var()
  unit1 = torch.sum(0.5 * (v - self.mu)**2/sigma2, axis=1)
  unit2 = torch.sum(torch.matmul(v/sigma2, self.W) * h, axis=1)
  unit3 = torch.sum(h * self.b, axis=1)

  return 1./v.shape[0] * (unit1 - unit2 - unit3)

compute_lossの冒頭、detach()して、サンプリングの過程を切り離す。

detach()しないでbackwardしちゃうと
例えば、 $\boldsymbol{h}$ をサンプリングする道中に $p(\boldsymbol{h}_j = 1\ |\ \boldsymbol{v}) = \Big[\text{Sigmoid}\left(\boldsymbol{W}^{\text{T}}\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2} + \boldsymbol{b} \right)\Big]_j$
という計算をしてるせいで、その $W$ とか $b$ とかについても偏微分されちゃうから。


N = train_v.shape[0]
batch_array = np.arange(0, N-nbatch, nbatch)

for epoch in range(nepoch):
  perm = np.random.permutation(N)
  for n in range(len(batch_array)):
    m = batch_array[n]
    batch_v = torch.tensor(train_v[perm[m:m+nbatch]].astype(np.float32))
    batch_ph = self.prob_h(batch_v)
    batch_v_, batch_ph_ = self.sampler()

    self.optimizer.zero_grad()
    loss = self.compute_loss(batch_v, batch_ph, batch_v_, batch_ph_)
    loss.backward() 
    self.norm_clip()
    self.optimizer.step()

皆様のコードはもっとカッコよく書いてますが初心者の自分には可読性低いと思いまして・・・

おまけ

エネルギーのvに関する偏微分

$\nabla_\boldsymbol{v}E(\boldsymbol{v}, \boldsymbol{h})$ を導出する。

\begin{align}
E(\boldsymbol{v}, \boldsymbol{h}) &= \frac{1}{2}\left(\frac{\boldsymbol{v}-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)^{\text{T}}\left(\frac{\boldsymbol{v}-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)
- \left(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\right)^{\text{T}}\boldsymbol{Wh} - \boldsymbol{b}^{\text{T}}\boldsymbol{h} \\

&= \frac{1}{2}\sum_i\left(\frac{v_i - \mu_i}{\sigma_i}\right)^2
 - \sum_i \frac{v_i}{\sigma_i^2}U_i - \boldsymbol{b}^{\text{T}}\boldsymbol{h} \hspace{10mm} (\ \boldsymbol{U} = \boldsymbol{Wh}\ )\\

\frac{\partial E}{\partial v_i} 
&= \frac{1}{2} \cdot 2 \frac{v_i - \mu_i}{\sigma_i}\cdot \frac{1}{\sigma_i} - \frac{U_i}{\sigma_i^2} \\

&= \frac{v_i - \mu_i}{\sigma_i^2} - \frac{U_i}{\sigma_i^2} \\

\frac{\partial E}{\partial \boldsymbol{v}} &= \frac{\boldsymbol{v} - \boldsymbol{\mu} - \boldsymbol{Wh}}{\boldsymbol{\sigma}^2}
\end{align}

パラメータ更新について

論分の勾配が符号間違ってると思うんだが・・・
論文では$\nabla\theta = -\frac{\partial E}{\partial \theta}$ って書いてるけどこれだと辻褄合わないんだよな。

以降、アルゴリズム図の中が正しいと仮定すると

\begin{align}
\theta &= \theta - r (\nabla\theta_+ - \nabla\theta_-) \\
&= \theta - r \Bigl(\bigl<\frac{\partial E}{\partial \theta}\bigr>_d - \bigl<\frac{\partial E}{\partial \theta}\bigr>_m\Bigr) \\

&= \theta + r \Bigl(\bigl<-\frac{\partial E}{\partial \theta}\bigr>_d - \bigl<-\frac{\partial E}{\partial \theta}\bigr>_m\Bigr) 
\end{align}

下記、符号は論文と逆だけど辻褄合います。

\begin{align}
\frac{\partial E}{\partial W_{ij}} &= -\frac{v_i}{\sigma_i^2}h_j \\

\frac{\partial E}{\partial \mu_i} &= -\frac{v_i - \mu_i}{\sigma_i^2} \\

\frac{\partial E}{\partial b_j} &= -h_j \\

\frac{\partial E}{\partial z_i} &= -\frac{(v_i-\mu_i)^2}{2\sigma_i^2}+\frac{\sum_jv_iW_{ij}h_j}{\sigma_i^2}
\end{align}

この導出は散々見かけるので割愛。と、見せかけて・・・


\begin{align}
\frac{\partial E}{\partial W_{ij}} 
&= \frac{\partial}{\partial W_{ij}} -\sum_i\frac{v_i}{\sigma_i^2}U_i \\

&= -\frac{v_i}{\sigma_i^2}\frac{\partial}{\partial W_{ij}}\sum_jW_{ij}h_j \\

&= -\frac{v_i}{\sigma_i^2}h_j
\end{align}

\begin{align}
\frac{\partial E}{\partial \mu_i} 
&= \frac{\partial}{\partial \mu_i} \frac{1}{2}\sum_i\left(\frac{v_i - \mu_i}{\sigma_i}\right)^2 \\

&= \frac{1}{2} \cdot 2 \frac{v_i - \mu_i}{\sigma_i} \cdot -\frac{1}{\sigma_i} \\

&= \frac{-v_i + \mu_i}{\sigma_i^2} \\
\end{align}

\begin{align}
\frac{\partial E}{\partial b_j} &= \frac{\partial}{\partial b_j}-\boldsymbol{b}^\text{T}\boldsymbol{h} \\

&= \frac{\partial}{\partial b_j}-\sum_j b_jh_j \\
&= -h_j
\end{align}

最後

\begin{align}
E(\boldsymbol{v}, \boldsymbol{h}) &= 
\frac{1}{2}\sum_i\left(\frac{v_i - \mu_i}{\sigma_i}\right)^2
 - \sum_i \frac{v_i}{\sigma_i^2}U_i - \boldsymbol{b}^{\text{T}}\boldsymbol{h} \hspace{10mm} (\ \boldsymbol{U} = \boldsymbol{Wh}\ )\\

&= \frac{1}{2}\sum_i(v_i - \mu_i)^2\exp(-z_i) - \sum_i v_iU_i\exp(-z_i) - \boldsymbol{b}^{\text{T}}\boldsymbol{h} \\
 
\frac{\partial E}{\partial z_i} 
&= \frac{1}{2}(v_i - \mu_i)^2\exp(-z_i) \cdot (-1) - v_iU_i\exp(-z_i)\cdot (-1) \\

&= -\frac{1}{2}\frac{(v_i-\mu_i)^2}{\exp(-z_i)} + \frac{v_iU_i}{\exp(-z_i)} \\

&= -\frac{1}{2}\frac{(v_i -\mu_i)^2}{\sigma_i^2} + \frac{v_iU_i}{\sigma_i^2} \\

&= -\frac{1}{2}\frac{(v_i -\mu_i)^2}{\sigma_i^2} + \frac{v_i\sum_jW_{ij}h_j}{\sigma_i^2} 
\end{align}

そもそも・・・
シンプルに、論分に書いてある$\nabla\theta$ を使って
$\theta = \theta + \eta\nabla\theta$ でOK。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?