Posted at

微分できない階段関数の最小値をvariational optimizationによって求めてみる


はじめに

微分できる関数であれば、最小値を求めるコードをPyTorchなどで書くのは簡単です。

しかし、微分できない関数の最小値を求めたいときは、特殊な方法を使う必要があります。

この記事では、variational optimization1と呼ばれる手法を使ってみます。


例題

次の関数$f(\theta)$の最小値を、variational optimizationによって求めてみます。

  f(\theta) = \begin{cases}

0 & (\theta < 0) \\
1 & (\theta \geq 0)
\end{cases}

最小値はもちろん0です。ついでに言えば、最大値は1です。


Variational optimization

Variational optimizationは、次の観察に基づいて、関数$f$を最小化する手法です。

\min_\boldsymbol{\theta} f(\boldsymbol{\theta}) \leq \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ] = U(\boldsymbol{\psi})

関数$f$の値の最小値は、関数$f$の値の平均をどんなふうにとっても(どんなふうに分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$を選んでも)、

必ずその平均以下になる、ということです。

もちろん、関数$f$の個々の値は、平均より大きくなることはあります。

しかし、その最小値は、関数$f$の値の平均をどんなふうにとっても、その平均以下です。

$f$は微分できないという想定ですので、代わりに上で定義した$U(\boldsymbol{\psi})$を最小化します。

そのため、$U(\boldsymbol{\psi})$の勾配を求めます。なお、$\boldsymbol{\psi}$は、確率分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$のパラメータです。

$U(\boldsymbol{\psi})$の勾配$\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})$を求める際には、以下のような式変形の結果を使います。

\begin{align}

\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
& = \nabla_\boldsymbol{\psi} \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ] \\
& = \nabla_\boldsymbol{\psi} \int q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta})
d\boldsymbol{\theta} \\
& = \int \nabla_\boldsymbol{\psi} q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta})
d\boldsymbol{\theta} \\
& = \int q(\boldsymbol{\theta} | \boldsymbol{\psi}) \nabla_\boldsymbol{\psi} \log q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta})
d\boldsymbol{\theta} \\
& = \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ \nabla_\boldsymbol{\psi} \log q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta}) ]
\end{align}

2行目から3行目の変形は自明ではないですが・・・そのあたりの話は割愛します2

最後の期待値はモンテカルロ積分で近似することができます。つまり、

$S$個のサンプル$\boldsymbol{\theta}^{(1)}, \ldots, \boldsymbol{\theta}^{(S)}$を$q(\boldsymbol{\theta} | \boldsymbol{\psi})$から取って、

\begin{align}

\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
\approx \frac{1}{S} \sum_{s=1}^S \nabla_\boldsymbol{\psi} \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})
\end{align}

とすれば近似できます。

なお、この和は有限な和なので、

\begin{align}

\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
\approx \nabla_\boldsymbol{\psi} \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})
\end{align}

と、分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$のパラメータでの偏微分を和の後で求めることもできます。

要注意なのは、この式は決して$U(\boldsymbol{\psi})

\approx \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})$を意味してはいないことです。

勾配がほぼ同じ$\approx$と言っているだけです。$U(\boldsymbol{\psi})$の定義はあくまで$U(\boldsymbol{\psi}) = \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ]$です。


例題の解き方

例題の階段関数を、上の議論での$f$とします。

$f$は一変数の関数ですので、確率分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$は一次元の正規分布$\mathcal{N}(\mu,\sigma)$にします。

正規分布の平均パラメータが$\mu$、標準偏差パラメータが$\sigma$です。

そして、上に示した

\begin{align}

\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
& \approx \nabla_\boldsymbol{\psi} \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})
\end{align}

この式で求まる勾配を使って、$\boldsymbol{\psi}=(\mu,\sigma)$を更新するコードをPyTorchで書くと、以下のようになります。

import torch

from torch.distributions import Normal

def myfunc(theta):
if theta < 0.0:
return 0.0
else:
return 1.0

torch.manual_seed(1)

mu = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)

optimizer = torch.optim.SGD([mu, log_sigma], lr=1.)

for i in range(100):
m = Normal(mu, torch.exp(log_sigma))
samples = m.sample((10,))
f = torch.tensor(tuple(map(myfunc, torch.unbind(samples)))).unsqueeze(1)
out = m.log_prob(samples)
out = (out * f).mean()
out.backward()
optimizer.step()
optimizer.zero_grad()
print('{:d} mu: {:.4f} ; sigma: {:.4f}'.format(i+1, mu.item(), torch.exp(log_sigma).item()))

私の趣味(?)で、標準偏差パラメータ$\sigma$は、その対数をとったもの$\log(\sigma)$でパラメータ化しています。

また、正規分布からサンプルを取るところと、対数尤度を求めるところは、

PyTorchのtorch.distributions.Normalsample()メソッドと、log_prob()メソッドを、

それぞれ利用しています。サンプルは10個取っていますが、この個数は適当に決めています。

上記のコードを実行すると

1 mu: -0.1993 ; sigma: 1.4853

2 mu: -0.2457 ; sigma: 1.7693
3 mu: -0.8013 ; sigma: 0.2927
4 mu: -0.8013 ; sigma: 0.2927
5 mu: -0.8013 ; sigma: 0.2927
6 mu: -0.8013 ; sigma: 0.2927
7 mu: -0.8013 ; sigma: 0.2927
8 mu: -0.8013 ; sigma: 0.2927
9 mu: -0.8013 ; sigma: 0.2927
10 mu: -0.8013 ; sigma: 0.2927
#(・・・中略・・・)
36 mu: -0.8013 ; sigma: 0.2927
37 mu: -0.8013 ; sigma: 0.2927
38 mu: -0.8013 ; sigma: 0.2927
39 mu: -0.8013 ; sigma: 0.2927
40 mu: -0.8013 ; sigma: 0.2927
41 mu: -0.8013 ; sigma: 0.2927
42 mu: -1.9106 ; sigma: 0.1127
43 mu: -1.9106 ; sigma: 0.1127
44 mu: -1.9106 ; sigma: 0.1127
45 mu: -1.9106 ; sigma: 0.1127
#(・・・中略・・・)
96 mu: -1.9106 ; sigma: 0.1127
97 mu: -1.9106 ; sigma: 0.1127
98 mu: -1.9106 ; sigma: 0.1127
99 mu: -1.9106 ; sigma: 0.1127
100 mu: -1.9106 ; sigma: 0.1127

となり、途中から$\mu=-1.9106, \sigma=0.1127$で動かなくなっています。

得られた$\mu, \sigma$を使って、$U(\boldsymbol{\psi}) = \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ]$に従って$U(\boldsymbol{\psi})$を計算してみます。上のコードの続きで

m = Normal(mu, torch.exp(log_sigma))

samples = m.sample((1000,))
print(torch.tensor(tuple(map(myfunc, torch.unbind(samples)))).mean().item())

と実行します。これはモンテカルロ積分による近似です。すると、

0.0

と表示されます。

最小化したい関数$f$の最小値は0でした。$U(\boldsymbol{\psi})$の値は上のように0ですので、正解が得られています。

もちろん、$\mu=-1.9106, \sigma=0.1127$でも、正規分布$\mathcal{N}(\mu,\sigma)$からのサンプル$\theta$が0以上になり、

よって$f(\theta)$が1になる確率はゼロではないですが、ほぼゼロなので、上の計算は0.0になっています。


おわりに

今回は例題が簡単だったのであっさり解けてしまいましたが、もっと難しい問題の場合は、

論文1の式(13)にあるbaselineと呼ばれる定数$b$を使う必要が出てくると思います。

その場合は、定数$b$を計算しておいて、これを

\begin{align}

\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
& \approx \nabla_\boldsymbol{\psi} \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) (f(\boldsymbol{\theta}^{(s)}) - b)
\end{align}

というように使うことになるのだと思います。





  1. Gilles Louppe, Joeri Hermans, and Kyle Cranmer. Adversarial Variational Optimization of Non-Differentiable Simulators. https://arxiv.org/abs/1707.07113 ←この論文の3.2節を参考にしました。この論文は他の内容も含んでいます。今回は、variational optimizationの説明の部分だけを参考にしました。 



  2. Joe Staines and David Barber. Variational Optimization. https://arxiv.org/abs/1212.4507 のSection 1.1参照。