はじめに
微分できる関数であれば、最小値を求めるコードをPyTorchなどで書くのは簡単です。
しかし、微分できない関数の最小値を求めたいときは、特殊な方法を使う必要があります。
この記事では、variational optimization1と呼ばれる手法を使ってみます。
例題
次の関数$f(\theta)$の最小値を、variational optimizationによって求めてみます。
f(\theta) = \begin{cases}
0 & (\theta < 0) \\
1 & (\theta \geq 0)
\end{cases}
最小値はもちろん0です。ついでに言えば、最大値は1です。
Variational optimization
Variational optimizationは、次の観察に基づいて、関数$f$を最小化する手法です。
\min_\boldsymbol{\theta} f(\boldsymbol{\theta}) \leq \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ] = U(\boldsymbol{\psi})
関数$f$の値の最小値は、関数$f$の値の平均をどんなふうにとっても(どんなふうに分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$を選んでも)、
必ずその平均以下になる、ということです。
もちろん、関数$f$の個々の値は、平均より大きくなることはあります。
しかし、その最小値は、関数$f$の値の平均をどんなふうにとっても、その平均以下です。
$f$は微分できないという想定ですので、代わりに上で定義した$U(\boldsymbol{\psi})$を最小化します。
そのため、$U(\boldsymbol{\psi})$の勾配を求めます。なお、$\boldsymbol{\psi}$は、確率分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$のパラメータです。
$U(\boldsymbol{\psi})$の勾配$\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})$を求める際には、以下のような式変形の結果を使います。
\begin{align}
\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
& = \nabla_\boldsymbol{\psi} \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ] \\
& = \nabla_\boldsymbol{\psi} \int q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta})
d\boldsymbol{\theta} \\
& = \int \nabla_\boldsymbol{\psi} q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta})
d\boldsymbol{\theta} \\
& = \int q(\boldsymbol{\theta} | \boldsymbol{\psi}) \nabla_\boldsymbol{\psi} \log q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta})
d\boldsymbol{\theta} \\
& = \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ \nabla_\boldsymbol{\psi} \log q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta}) ]
\end{align}
2行目から3行目の変形は自明ではないですが・・・そのあたりの話は割愛します2。
最後の期待値はモンテカルロ積分で近似することができます。つまり、
$S$個のサンプル$\boldsymbol{\theta}^{(1)}, \ldots, \boldsymbol{\theta}^{(S)}$を$q(\boldsymbol{\theta} | \boldsymbol{\psi})$から取って、
\begin{align}
\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
\approx \frac{1}{S} \sum_{s=1}^S \nabla_\boldsymbol{\psi} \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})
\end{align}
とすれば近似できます。
なお、この和は有限な和なので、
\begin{align}
\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
\approx \nabla_\boldsymbol{\psi} \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})
\end{align}
と、分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$のパラメータでの偏微分を和の後で求めることもできます。
要注意なのは、この式は決して$U(\boldsymbol{\psi})
\approx \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})$を意味してはいないことです。
勾配がほぼ同じ$\approx$と言っているだけです。$U(\boldsymbol{\psi})$の定義はあくまで$U(\boldsymbol{\psi}) = \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ]$です。
例題の解き方
例題の階段関数を、上の議論での$f$とします。
$f$は一変数の関数ですので、確率分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$は一次元の正規分布$\mathcal{N}(\mu,\sigma)$にします。
正規分布の平均パラメータが$\mu$、標準偏差パラメータが$\sigma$です。
そして、上に示した
\begin{align}
\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
& \approx \nabla_\boldsymbol{\psi} \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})
\end{align}
この式で求まる勾配を使って、$\boldsymbol{\psi}=(\mu,\sigma)$を更新するコードをPyTorchで書くと、以下のようになります。
import torch
from torch.distributions import Normal
def myfunc(theta):
if theta < 0.0:
return 0.0
else:
return 1.0
torch.manual_seed(1)
mu = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.SGD([mu, log_sigma], lr=1.)
for i in range(100):
m = Normal(mu, torch.exp(log_sigma))
samples = m.sample((10,))
f = torch.tensor(tuple(map(myfunc, torch.unbind(samples)))).unsqueeze(1)
out = m.log_prob(samples)
out = (out * f).mean()
out.backward()
optimizer.step()
optimizer.zero_grad()
print('{:d} mu: {:.4f} ; sigma: {:.4f}'.format(i+1, mu.item(), torch.exp(log_sigma).item()))
私の趣味(?)で、標準偏差パラメータ$\sigma$は、その対数をとったもの$\log(\sigma)$でパラメータ化しています。
また、正規分布からサンプルを取るところと、対数尤度を求めるところは、
PyTorchのtorch.distributions.Normal
のsample()
メソッドと、log_prob()
メソッドを、
それぞれ利用しています。サンプルは10個取っていますが、この個数は適当に決めています。
上記のコードを実行すると
1 mu: -0.1993 ; sigma: 1.4853
2 mu: -0.2457 ; sigma: 1.7693
3 mu: -0.8013 ; sigma: 0.2927
4 mu: -0.8013 ; sigma: 0.2927
5 mu: -0.8013 ; sigma: 0.2927
6 mu: -0.8013 ; sigma: 0.2927
7 mu: -0.8013 ; sigma: 0.2927
8 mu: -0.8013 ; sigma: 0.2927
9 mu: -0.8013 ; sigma: 0.2927
10 mu: -0.8013 ; sigma: 0.2927
# (・・・中略・・・)
36 mu: -0.8013 ; sigma: 0.2927
37 mu: -0.8013 ; sigma: 0.2927
38 mu: -0.8013 ; sigma: 0.2927
39 mu: -0.8013 ; sigma: 0.2927
40 mu: -0.8013 ; sigma: 0.2927
41 mu: -0.8013 ; sigma: 0.2927
42 mu: -1.9106 ; sigma: 0.1127
43 mu: -1.9106 ; sigma: 0.1127
44 mu: -1.9106 ; sigma: 0.1127
45 mu: -1.9106 ; sigma: 0.1127
# (・・・中略・・・)
96 mu: -1.9106 ; sigma: 0.1127
97 mu: -1.9106 ; sigma: 0.1127
98 mu: -1.9106 ; sigma: 0.1127
99 mu: -1.9106 ; sigma: 0.1127
100 mu: -1.9106 ; sigma: 0.1127
となり、途中から$\mu=-1.9106, \sigma=0.1127$で動かなくなっています。
得られた$\mu, \sigma$を使って、$U(\boldsymbol{\psi}) = \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ]$に従って$U(\boldsymbol{\psi})$を計算してみます。上のコードの続きで
m = Normal(mu, torch.exp(log_sigma))
samples = m.sample((1000,))
print(torch.tensor(tuple(map(myfunc, torch.unbind(samples)))).mean().item())
と実行します。これはモンテカルロ積分による近似です。すると、
0.0
と表示されます。
最小化したい関数$f$の最小値は0でした。$U(\boldsymbol{\psi})$の値は上のように0ですので、正解が得られています。
もちろん、$\mu=-1.9106, \sigma=0.1127$でも、正規分布$\mathcal{N}(\mu,\sigma)$からのサンプル$\theta$が0以上になり、
よって$f(\theta)$が1になる確率はゼロではないですが、ほぼゼロなので、上の計算は0.0になっています。
おわりに
今回は例題が簡単だったのであっさり解けてしまいましたが、もっと難しい問題の場合は、
論文1の式(13)にあるbaselineと呼ばれる定数$b$を使う必要が出てくると思います。
その場合は、定数$b$を計算しておいて、これを
\begin{align}
\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
& \approx \nabla_\boldsymbol{\psi} \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) (f(\boldsymbol{\theta}^{(s)}) - b)
\end{align}
というように使うことになるのだと思います。
-
Gilles Louppe, Joeri Hermans, and Kyle Cranmer. Adversarial Variational Optimization of Non-Differentiable Simulators. https://arxiv.org/abs/1707.07113 ←この論文の3.2節を参考にしました。この論文は他の内容も含んでいます。今回は、variational optimizationの説明の部分だけを参考にしました。 ↩ ↩2
-
Joe Staines and David Barber. Variational Optimization. https://arxiv.org/abs/1212.4507 のSection 1.1参照。 ↩