Python
機械学習
PyTorch

微分できない階段関数の最小値をvariational optimizationによって求めてみる

はじめに

微分できる関数であれば、最小値を求めるコードをPyTorchなどで書くのは簡単です。
しかし、微分できない関数の最小値を求めたいときは、特殊な方法を使う必要があります。
この記事では、variational optimization1と呼ばれる手法を使ってみます。

例題

次の関数$f(\theta)$の最小値を、variational optimizationによって求めてみます。

  f(\theta) = \begin{cases}
    0 & (\theta < 0) \\
    1 & (\theta \geq 0)
  \end{cases}

最小値はもちろん0です。ついでに言えば、最大値は1です。

Variational optimization

Variational optimizationは、次の観察に基づいて、関数$f$を最小化する手法です。

\min_\boldsymbol{\theta} f(\boldsymbol{\theta}) \leq \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ] = U(\boldsymbol{\psi})

関数$f$の値の最小値は、関数$f$の値の平均をどんなふうにとっても(どんなふうに分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$を選んでも)、
必ずその平均以下になる、ということです。
もちろん、関数$f$の個々の値は、平均より大きくなることはあります。
しかし、その最小値は、関数$f$の値の平均をどんなふうにとっても、その平均以下です。

$f$は微分できないという想定ですので、代わりに上で定義した$U(\boldsymbol{\psi})$を最小化します。
そのため、$U(\boldsymbol{\psi})$の勾配を求めます。なお、$\boldsymbol{\psi}$は、確率分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$のパラメータです。

$U(\boldsymbol{\psi})$の勾配$\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})$を求める際には、以下のような式変形の結果を使います。

\begin{align}
\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
& = \nabla_\boldsymbol{\psi} \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ] \\
& = \nabla_\boldsymbol{\psi} \int q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta})
d\boldsymbol{\theta} \\
& = \int \nabla_\boldsymbol{\psi} q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta})
d\boldsymbol{\theta} \\
& = \int q(\boldsymbol{\theta} | \boldsymbol{\psi}) \nabla_\boldsymbol{\psi} \log q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta})
d\boldsymbol{\theta} \\
& = \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ \nabla_\boldsymbol{\psi} \log q(\boldsymbol{\theta} | \boldsymbol{\psi}) f(\boldsymbol{\theta}) ]
\end{align}

2行目から3行目の変形は自明ではないですが・・・そのあたりの話は割愛します2
最後の期待値はモンテカルロ積分で近似することができます。つまり、
$S$個のサンプル$\boldsymbol{\theta}^{(1)}, \ldots, \boldsymbol{\theta}^{(S)}$を$q(\boldsymbol{\theta} | \boldsymbol{\psi})$から取って、

\begin{align}
\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
\approx \frac{1}{S} \sum_{s=1}^S \nabla_\boldsymbol{\psi} \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})
\end{align}

とすれば近似できます。

なお、この和は有限な和なので、

\begin{align}
\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
\approx \nabla_\boldsymbol{\psi} \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})
\end{align}

と、分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$のパラメータでの偏微分を和の後で求めることもできます。
要注意なのは、この式は決して$U(\boldsymbol{\psi})
\approx \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})$を意味してはいないことです。
勾配がほぼ同じ$\approx$と言っているだけです。$U(\boldsymbol{\psi})$の定義はあくまで$U(\boldsymbol{\psi}) = \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ]$です。

例題の解き方

例題の階段関数を、上の議論での$f$とします。
$f$は一変数の関数ですので、確率分布$q(\boldsymbol{\theta} | \boldsymbol{\psi})$は一次元の正規分布$\mathcal{N}(\mu,\sigma)$にします。
正規分布の平均パラメータが$\mu$、標準偏差パラメータが$\sigma$です。

そして、上に示した

\begin{align}
\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
& \approx \nabla_\boldsymbol{\psi} \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) f(\boldsymbol{\theta}^{(s)})
\end{align}

この式で求まる勾配を使って、$\boldsymbol{\psi}=(\mu,\sigma)$を更新するコードをPyTorchで書くと、以下のようになります。

import torch
from torch.distributions import Normal

def myfunc(theta):
  if theta < 0.0:
    return 0.0
  else:
    return 1.0

torch.manual_seed(1)

mu = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)

optimizer = torch.optim.SGD([mu, log_sigma], lr=1.)

for i in range(100):
  m = Normal(mu, torch.exp(log_sigma))
  samples = m.sample((10,))
  f = torch.tensor(tuple(map(myfunc, torch.unbind(samples)))).unsqueeze(1)
  out = m.log_prob(samples)
  out = (out * f).mean()
  out.backward()
  optimizer.step()
  optimizer.zero_grad()
  print('{:d} mu: {:.4f} ; sigma: {:.4f}'.format(i+1, mu.item(), torch.exp(log_sigma).item()))

私の趣味(?)で、標準偏差パラメータ$\sigma$は、その対数をとったもの$\log(\sigma)$でパラメータ化しています。
また、正規分布からサンプルを取るところと、対数尤度を求めるところは、
PyTorchのtorch.distributions.Normalsample()メソッドと、log_prob()メソッドを、
それぞれ利用しています。サンプルは10個取っていますが、この個数は適当に決めています。

上記のコードを実行すると

1 mu: -0.1993 ; sigma: 1.4853
2 mu: -0.2457 ; sigma: 1.7693
3 mu: -0.8013 ; sigma: 0.2927
4 mu: -0.8013 ; sigma: 0.2927
5 mu: -0.8013 ; sigma: 0.2927
6 mu: -0.8013 ; sigma: 0.2927
7 mu: -0.8013 ; sigma: 0.2927
8 mu: -0.8013 ; sigma: 0.2927
9 mu: -0.8013 ; sigma: 0.2927
10 mu: -0.8013 ; sigma: 0.2927
#(・・・中略・・・)
36 mu: -0.8013 ; sigma: 0.2927
37 mu: -0.8013 ; sigma: 0.2927
38 mu: -0.8013 ; sigma: 0.2927
39 mu: -0.8013 ; sigma: 0.2927
40 mu: -0.8013 ; sigma: 0.2927
41 mu: -0.8013 ; sigma: 0.2927
42 mu: -1.9106 ; sigma: 0.1127
43 mu: -1.9106 ; sigma: 0.1127
44 mu: -1.9106 ; sigma: 0.1127
45 mu: -1.9106 ; sigma: 0.1127
#(・・・中略・・・)
96 mu: -1.9106 ; sigma: 0.1127
97 mu: -1.9106 ; sigma: 0.1127
98 mu: -1.9106 ; sigma: 0.1127
99 mu: -1.9106 ; sigma: 0.1127
100 mu: -1.9106 ; sigma: 0.1127

となり、途中から$\mu=-1.9106, \sigma=0.1127$で動かなくなっています。

得られた$\mu, \sigma$を使って、$U(\boldsymbol{\psi}) = \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} | \boldsymbol{\psi})} [ f(\boldsymbol{\theta}) ]$に従って$U(\boldsymbol{\psi})$を計算してみます。上のコードの続きで

m = Normal(mu, torch.exp(log_sigma))
samples = m.sample((1000,))
print(torch.tensor(tuple(map(myfunc, torch.unbind(samples)))).mean().item())

と実行します。これはモンテカルロ積分による近似です。すると、

0.0

と表示されます。
最小化したい関数$f$の最小値は0でした。$U(\boldsymbol{\psi})$の値は上のように0ですので、正解が得られています。
もちろん、$\mu=-1.9106, \sigma=0.1127$でも、正規分布$\mathcal{N}(\mu,\sigma)$からのサンプル$\theta$が0以上になり、
よって$f(\theta)$が1になる確率はゼロではないですが、ほぼゼロなので、上の計算は0.0になっています。

おわりに

今回は例題が簡単だったのであっさり解けてしまいましたが、もっと難しい問題の場合は、
論文1の式(13)にあるbaselineと呼ばれる定数$b$を使う必要が出てくると思います。
その場合は、定数$b$を計算しておいて、これを

\begin{align}
\nabla_\boldsymbol{\psi} U(\boldsymbol{\psi})
& \approx \nabla_\boldsymbol{\psi} \frac{1}{S} \sum_{s=1}^S \log q(\boldsymbol{\theta}^{(s)} | \boldsymbol{\psi}) (f(\boldsymbol{\theta}^{(s)}) - b)
\end{align}

というように使うことになるのだと思います。


  1. Gilles Louppe, Joeri Hermans, and Kyle Cranmer. Adversarial Variational Optimization of Non-Differentiable Simulators. https://arxiv.org/abs/1707.07113 ←この論文の3.2節を参考にしました。この論文は他の内容も含んでいます。今回は、variational optimizationの説明の部分だけを参考にしました。 

  2. Joe Staines and David Barber. Variational Optimization. https://arxiv.org/abs/1212.4507 のSection 1.1参照。