RBMs Without Tears RBMの再来
・RBMs Without Tears
・GBRMs
・Langevin Sampling
・Gibbs-Langevin Sampling
・全体の学習アルゴリズム
・結果
・実装
・おまけ
RBMs Without Tears
2022年にHinton先生らが発表した論分。
「Gaussian-Bernoulli RBMs Without Tears」を触ってみる。
VAEやGANばっかりが活躍してて毎晩涙で枕を濡らしてるRestricted Boltzmann Machinesがついに泣き止んだそう。
without tears で「工夫」みたいな意味になるの英語むずい
今回のミソ
① GibbsサンプリングにLangevinサンプリング(ランジュバンサンプリング)を組み込む
② 勾配のノルムをクリッピングする
そうした結果、オリジナルVAE並の精度を発揮するようになった。
GAN+VAEには遠く及ばないけど
隠れ層が1層しかないRBMでこの快挙はスゴイ!というのが主張。
ということで!このGibbs-Langevinサンプリングを導入したアルゴリズムを実装する。
※論文とはちょこちょこ変数名変えてる
GBRBMs
Gaussian-Bernoulli Restricted Boltzmann Machinesの定式はこちら。これは基本的なGBRBMsと一緒なので特に問題はないかと。
\begin{align}
E(\boldsymbol{v}, \boldsymbol{h}) &= \frac{1}{2}\left(\frac{\boldsymbol{v}-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)^{\text{T}}\left(\frac{\boldsymbol{v}-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)
- \left(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\right)^{\text{T}}\boldsymbol{Wh} - \boldsymbol{b}^{\text{T}}\boldsymbol{h} \\
p(\boldsymbol{v}\ |\ \boldsymbol{h}) &= \mathcal{N}\left(\boldsymbol{v}\ |\ \boldsymbol{Wh} + \boldsymbol{\mu},\ diag(\boldsymbol{\sigma}^2)\right) \\
p(\boldsymbol{h}_j = 1\ |\ \boldsymbol{v}) &= \Big[\text{Sigmoid}\left(\boldsymbol{W}^{\text{T}}\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2} + \boldsymbol{b} \right)\Big]_j
\end{align}
Langevin Sampling
\begin{align}
\boldsymbol{x}^k &= \boldsymbol{x}^{k-1} - \gamma \frac{\partial E(\boldsymbol{x})}{\partial \boldsymbol{v}} + \sqrt{2\gamma}\xi_{k} \\
\xi_k &\sim \mathcal{N}(\xi_k\ |\ \boldsymbol{0}, \boldsymbol{1})
\end{align}
で得られるサンプリング法をLangevin Monte Carlo法という。
何にでも使えるけど、今回で言うと第2項の導関数はRBMのエネルギー関数を使う。
$\gamma$ はステップ幅。
Gibbs-Langevin Sampling
通常のギブスサンプリング(Contrastive Divergence法)では、$\boldsymbol{v}^0 = \boldsymbol{v_{data}}$ として、
サンプリングを $K$ 回繰り返した $\boldsymbol{v}^K$ をモデルの出力とする。
\begin{align}
\boldsymbol{v}^k &\sim p(\boldsymbol{v}|\boldsymbol{h}^{k-1}) \\
\boldsymbol{h}^{k-1} &\sim p(\boldsymbol{h}|\boldsymbol{v}^{k-1})
\end{align}
多くの場合、K=1
今回は、RBMの涙を止めるために $\boldsymbol{v}$ のサンプリングのみ、ランジュバンサンプリングを使う。
このとき、サンプリングの最初 $\boldsymbol{v_0}$ は標準正規分布のノイズにすると良いらしい。
$\alpha_m$ はコサインスケジューラーで管理。初期値は20。
\alpha_m = \frac{1}{2}\alpha_0\left(1 + \cos\left(\frac{m}{M}\pi\right)\right)
また、$\nabla_\boldsymbol{v}E(\boldsymbol{v}, \boldsymbol{h})$ は
\nabla_\boldsymbol{v}E(\boldsymbol{v}, \boldsymbol{h})
= \frac{\boldsymbol{v} - \boldsymbol{\mu} - \boldsymbol{Wh}}{\boldsymbol{\sigma}^2}
導出は「おまけ」へ
全体の学習アルゴリズム
注意!!
論文では
・メトロポリス調整をするかどうか
・サンプリングに $\boldsymbol{h}$ を含めるかどうか
議論してるけど、画像生成の実験で、「$\boldsymbol{h}$ は含めて調整無し」が一番精度良かったので
今回の実装でもそうしてる。
結果
39次元のメルケプストラムで75発話(フレーム数80,000弱)で、隠れ層8次元、学習率0.01、ミニバッチ数128、CD法のループ数20、ランジュバンのループ数10、エポック数30で回した結果。
全然じゃん!!!
論文ではもっとドデカいデータ数で、ループ数ももっと大きいから
それくらい規模がでかいと有効なのかも。
スモールなプロジェクトだとあんましって感じ?
実装
def sampler(self):
sigma2_mean = torch.mean(self.get_var()).detach()
v = torch.randn(self.nbatch, self.I)
ph = self.prob_h(v)
h = torch.bernoulli(ph)
tmp = cosine_schedule(a0=20 *sigma2_mean, T = self.Lstep)
for k in range(self.Kstep): # CD法のループ
for l in range(self.Lstep): # ランジュバンサンプリングのループ
grad_v = self.energy_grad_v(v, h)
e = torch.randn(v.shape[0], v.shape[1])
v = v - tmp[l]*grad_v + np.sqrt(2*tmp[l])*e
ph = self.prob_h(v)
h = torch.bernoulli(ph)
return v, ph
↑これはさっき載せたアルゴリズムの箇所。
↓その他
##############################################################################
def get_var(self):
return torch.exp(self.z).clip(min=1e-8)
##############################################################################
def prob_h(self, v):
v_ = v / self.get_var()
ph = torch.matmul(v_, self.W) + self.b
return torch.sigmoid(ph)
##############################################################################
def prob_v(self, h):
return torch.matmul(h, self.W.T) + self.mu
##############################################################################
def energy_grad_v(self, v, h):
sigma2 = self.get_var()
grad_v = (v - self.mu) - torch.matmul(h, self.W.T)
return grad_v / sigma2 / v.shape[0]
##############################################################################
# L2ノルムなのかL1ノルムなのか知らんけど、とりあえずL2で。
def norm_clip(self, th=10., norm_type=2):
nn.utils.clip_grad_norm_(\
parameters=self.parameters(), max_norm=th, norm_type=norm_type)
論分に書いてある各パラメータの勾配の導出は「おまけ」へ
pytorchで実装するからロスをちゃんと定義してあげる。
##############################################################################
def compute_loss(self, v, h, v_, h_):
v = v.clone().detach()
h = h.clone().detach()
v_ = v_.clone().detach()
h_ = h_.clone().detach()
E = self.energy(v, h)
E_ = self.energy(v_, h_)
return -torch.mean(-E + E_)
##############################################################################
def energy(self, v, h):
sigma2 = self.get_var()
unit1 = torch.sum(0.5 * (v - self.mu)**2/sigma2, axis=1)
unit2 = torch.sum(torch.matmul(v/sigma2, self.W) * h, axis=1)
unit3 = torch.sum(h * self.b, axis=1)
return 1./v.shape[0] * (unit1 - unit2 - unit3)
compute_lossの冒頭、detach()して、サンプリングの過程を切り離す。
detach()しないでbackwardしちゃうと
例えば、 $\boldsymbol{h}$ をサンプリングする道中に $p(\boldsymbol{h}_j = 1\ |\ \boldsymbol{v}) = \Big[\text{Sigmoid}\left(\boldsymbol{W}^{\text{T}}\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2} + \boldsymbol{b} \right)\Big]_j$
という計算をしてるせいで、その $W$ とか $b$ とかについても偏微分されちゃうから。
N = train_v.shape[0]
batch_array = np.arange(0, N-nbatch, nbatch)
for epoch in range(nepoch):
perm = np.random.permutation(N)
for n in range(len(batch_array)):
m = batch_array[n]
batch_v = torch.tensor(train_v[perm[m:m+nbatch]].astype(np.float32))
batch_ph = self.prob_h(batch_v)
batch_v_, batch_ph_ = self.sampler()
self.optimizer.zero_grad()
loss = self.compute_loss(batch_v, batch_ph, batch_v_, batch_ph_)
loss.backward()
self.norm_clip()
self.optimizer.step()
皆様のコードはもっとカッコよく書いてますが初心者の自分には可読性低いと思いまして・・・
おまけ
エネルギーのvに関する偏微分
$\nabla_\boldsymbol{v}E(\boldsymbol{v}, \boldsymbol{h})$ を導出する。
\begin{align}
E(\boldsymbol{v}, \boldsymbol{h}) &= \frac{1}{2}\left(\frac{\boldsymbol{v}-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)^{\text{T}}\left(\frac{\boldsymbol{v}-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)
- \left(\frac{\boldsymbol{v}}{\boldsymbol{\sigma}^2}\right)^{\text{T}}\boldsymbol{Wh} - \boldsymbol{b}^{\text{T}}\boldsymbol{h} \\
&= \frac{1}{2}\sum_i\left(\frac{v_i - \mu_i}{\sigma_i}\right)^2
- \sum_i \frac{v_i}{\sigma_i^2}U_i - \boldsymbol{b}^{\text{T}}\boldsymbol{h} \hspace{10mm} (\ \boldsymbol{U} = \boldsymbol{Wh}\ )\\
\frac{\partial E}{\partial v_i}
&= \frac{1}{2} \cdot 2 \frac{v_i - \mu_i}{\sigma_i}\cdot \frac{1}{\sigma_i} - \frac{U_i}{\sigma_i^2} \\
&= \frac{v_i - \mu_i}{\sigma_i^2} - \frac{U_i}{\sigma_i^2} \\
\frac{\partial E}{\partial \boldsymbol{v}} &= \frac{\boldsymbol{v} - \boldsymbol{\mu} - \boldsymbol{Wh}}{\boldsymbol{\sigma}^2}
\end{align}
パラメータ更新について
論分の勾配が符号間違ってると思うんだが・・・
論文では$\nabla\theta = -\frac{\partial E}{\partial \theta}$ って書いてるけどこれだと辻褄合わないんだよな。
以降、アルゴリズム図の中が正しいと仮定すると
\begin{align}
\theta &= \theta - r (\nabla\theta_+ - \nabla\theta_-) \\
&= \theta - r \Bigl(\bigl<\frac{\partial E}{\partial \theta}\bigr>_d - \bigl<\frac{\partial E}{\partial \theta}\bigr>_m\Bigr) \\
&= \theta + r \Bigl(\bigl<-\frac{\partial E}{\partial \theta}\bigr>_d - \bigl<-\frac{\partial E}{\partial \theta}\bigr>_m\Bigr)
\end{align}
下記、符号は論文と逆だけど辻褄合います。
\begin{align}
\frac{\partial E}{\partial W_{ij}} &= -\frac{v_i}{\sigma_i^2}h_j \\
\frac{\partial E}{\partial \mu_i} &= -\frac{v_i - \mu_i}{\sigma_i^2} \\
\frac{\partial E}{\partial b_j} &= -h_j \\
\frac{\partial E}{\partial z_i} &= -\frac{(v_i-\mu_i)^2}{2\sigma_i^2}+\frac{\sum_jv_iW_{ij}h_j}{\sigma_i^2}
\end{align}
この導出は散々見かけるので割愛。と、見せかけて・・・
\begin{align}
\frac{\partial E}{\partial W_{ij}}
&= \frac{\partial}{\partial W_{ij}} -\sum_i\frac{v_i}{\sigma_i^2}U_i \\
&= -\frac{v_i}{\sigma_i^2}\frac{\partial}{\partial W_{ij}}\sum_jW_{ij}h_j \\
&= -\frac{v_i}{\sigma_i^2}h_j
\end{align}
\begin{align}
\frac{\partial E}{\partial \mu_i}
&= \frac{\partial}{\partial \mu_i} \frac{1}{2}\sum_i\left(\frac{v_i - \mu_i}{\sigma_i}\right)^2 \\
&= \frac{1}{2} \cdot 2 \frac{v_i - \mu_i}{\sigma_i} \cdot -\frac{1}{\sigma_i} \\
&= \frac{-v_i + \mu_i}{\sigma_i^2} \\
\end{align}
次
\begin{align}
\frac{\partial E}{\partial b_j} &= \frac{\partial}{\partial b_j}-\boldsymbol{b}^\text{T}\boldsymbol{h} \\
&= \frac{\partial}{\partial b_j}-\sum_j b_jh_j \\
&= -h_j
\end{align}
最後
\begin{align}
E(\boldsymbol{v}, \boldsymbol{h}) &=
\frac{1}{2}\sum_i\left(\frac{v_i - \mu_i}{\sigma_i}\right)^2
- \sum_i \frac{v_i}{\sigma_i^2}U_i - \boldsymbol{b}^{\text{T}}\boldsymbol{h} \hspace{10mm} (\ \boldsymbol{U} = \boldsymbol{Wh}\ )\\
&= \frac{1}{2}\sum_i(v_i - \mu_i)^2\exp(-z_i) - \sum_i v_iU_i\exp(-z_i) - \boldsymbol{b}^{\text{T}}\boldsymbol{h} \\
\frac{\partial E}{\partial z_i}
&= \frac{1}{2}(v_i - \mu_i)^2\exp(-z_i) \cdot (-1) - v_iU_i\exp(-z_i)\cdot (-1) \\
&= -\frac{1}{2}\frac{(v_i-\mu_i)^2}{\exp(-z_i)} + \frac{v_iU_i}{\exp(-z_i)} \\
&= -\frac{1}{2}\frac{(v_i -\mu_i)^2}{\sigma_i^2} + \frac{v_iU_i}{\sigma_i^2} \\
&= -\frac{1}{2}\frac{(v_i -\mu_i)^2}{\sigma_i^2} + \frac{v_i\sum_jW_{ij}h_j}{\sigma_i^2}
\end{align}
そもそも・・・
シンプルに、論分に書いてある$\nabla\theta$ を使って
$\theta = \theta + \eta\nabla\theta$ でOK。