DeepLearning
statistics
Chainer

Chainerで確率分布が扱えるようになりました。chainer.distributionsの紹介

Chainer v5 から導入された chainer.distributions

Chainerに確率分布を扱う機能、Distributionがv5から追加されました。これを使うことによって、VAEなどの生成モデルなどが書きやすくなります。

Distributionでは、主に以下の計算が微分可能な形でできます。

  • 対数確率の計算
  • サンプリング(一部微分不可能な分布あり)
  • 分布同士のKL Divergence

このような機能があると、どのように便利なのでしょうか。

機械学習には確率分布が使われている

機械学習で2値分類問題を考えましょう。これは、入力$x_i$に対応する出力$t_i$が1か0かを判定する問題です。
まず$\theta$をパラメーターとして持つ$f$というモデルを考えます。

y_i = f_{\theta}(x_i)

その時に、通常、以下の形の誤差関数を最小化するように$\theta$を最適化します。

L(x_{1...N}) = -\sum_i{t_i\log{y_i} + (1-t_i)\log(1-y_i)}\tag{1}

なぜ、このような関数を最適化すると良いのでしょうか?

確率という観点で見ると、実はこれはベルヌーイ分布という分布の最尤推定であると解釈することができます。

P(t; p) = \begin{cases}
    p & (t =1) \\
    (1-p) & (t=0)
  \end{cases}

確率$p$で$t=1$になって、そうじゃない時は$t=0$になるという簡単な分布なのですが、機械学習では、このパラメーター$p$を推定しています。

p_i = y_i = f_{\theta}(x_i)

この推定したパラメーターを元に$t_{1...N}$の確率を書くと、$p_{1...N}$の尤度というものになります。

P(t_{1...N}; p_{1...N}) = \prod_i{P(t_i; p_i)} = \prod_i \begin{cases}
    p_i & (t_i =1) \\
    (1-p_i) & (t_i=0)
  \end{cases}

この値を最大化するわけですが、この値は小さくなりすぎる場合があるので、コンピューターで計算する上では、対数をとって、マイナスを付けて、ロス関数とします。

-\log{P(t_{1...N}; p_{1...N})} = -\sum_i{\log{P(t_i; p_i)}} = -\sum_i{ \begin{cases}
    \log{p_i} & (t_i =1) \\
    \log{(1-p_i)} & (t_i=0)
  \end{cases}} \\
= -\sum_i{t_i\log{p_i} + (1-t_i)\log(1-p_i)}

となり、(1)式と同じものが出てきます。

このように機械学習では、確率分布が密接に関わっています。分類のロス関数だけでなく、正則化やVAEなども確率という観点から解釈することができます。ただし、実際に機械学習を行う時に、毎回手計算で、上のような確率操作を行うのはちょっと大変だと思います。

chainer.distributionsでの最尤法

そこでchainer.distributionsでは、このような対数確率の計算などを簡単な形で行えるようになっています。

import chainer.distributions as D

p = f(x)
loss = F.sum(D.Bernoulli(p).log_prob(t))

実際には、上のコードは不安定で、ベルヌーイ分布のパラメーターとしてはlogitを使うことが多いです。

loss = F.sum(D.Bernoulli(logit=f(x)).log_prob(t))

実際にはこれはF.sigmoid_cross_entropyを使えば良いわけですが、例えば、ポアソン回帰等、他の分布を使う時に便利になります。また、確率的な観点から直感的な書き方ができるようになります。

VAEの例

実際に直感的な書き方ができる例として、VAEがあります。
すでに、ChainerのVAEのExamplechainer.distributionsが使われています。

モデルについて

VAEは観測値$x$、潜在変数$z$が登場する確率モデルで、$x$をサンプリングする用途で使われます。
VAEではEncoderとDecoderを確率分布として定義します。

q_{\phi}(z|x) \\
p_{\theta}(x|z)

また、潜在変数の確率分布(Prior)を正規分布として定義します。(必ずしも正規分布である必要はありません)

p(z) = N(0, 1)

こうすることにより、$\theta$を最適化することにより、サンプリングすることができます。
$\theta$の最適化は、以下の変分下限を最大化します(詳細は論文、あるいは他の解説記事を参照してください)

\mathcal{L}(\theta; \phi; x) = E_{q_{\phi}(z|x)}[-log{q_{\phi}(z|x)}+\log{p_{\theta}(x,z)}] \\
 = E_{q_{\phi}(z|x)}[-log{q_{\phi}(z|x)}+\log{p_{\theta}(x|z)}+\log{p(z)}] \\
= - KL(q_{\phi}(z|x)|p(z)) + E_{q_{\phi}(z|x)}[\log{p_{\theta}(x|z)}]

期待値計算はモンテカルロ積分で行います。

サンプルコード

サンプルコードではEncoder, Decoder, PriorをそれぞれDistributionのインスタンスを返すLinkChainで表しています。

class Encoder(chainer.Chain):

    def __init__(self, n_in, n_latent, n_h):
        super(Encoder, self).__init__()
        ...

    def forward(self, x):
        h = F.tanh(self.linear(x))
        mu = self.mu(h)
        ln_sigma = self.ln_sigma(h)  # log(sigma)
        return D.Normal(loc=mu, log_scale=ln_sigma)
class Decoder(chainer.Chain):

    def __init__(self, n_in, n_latent, n_h, binary_check=False):
        super(Decoder, self).__init__()
        ...

    def forward(self, z, inference=False):
        n_batch_axes = 1 if inference else 2
        h = F.tanh(self.linear(z, n_batch_axes=n_batch_axes))
        h = self.output(h, n_batch_axes=n_batch_axes)
        return D.Bernoulli(logit=h, binary_check=self.binary_check)
class Prior(chainer.Link):

    def __init__(self, n_latent):
        super(Prior, self).__init__()
                ...

    def forward(self):
        return D.Normal(self.loc, scale=self.scale)

そしてこれをAvgELBOLossというクラスに入れることで、VAEのロスが計算できるようになっています。

class AvgELBOLoss(chainer.Chain):
    ...

    def __init__(self, encoder, decoder, prior, beta=1.0, k=1):
        super(AvgELBOLoss, self).__init__()
        ...

    def __call__(self, x):
        q_z = self.encoder(x)
        z = q_z.sample(self.k)
        p_x = self.decoder(z)
        p_z = self.prior()

        reconstr = F.mean(F.sum(p_x.log_prob(
            F.broadcast_to(x[None, :], (self.k,) + x.shape)), axis=-1))
        kl_penalty = F.mean(F.sum(chainer.kl_divergence(q_z, p_z), axis=-1))
        loss = - (reconstr - self.beta * kl_penalty)
        ...
        return loss