2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

離散混合ロジスティック分布(the discretized mixture of logistics)

Posted at

はじめに

定期的に VAE の新しい論文が公開される昨今いかがお過ごしでしょうか
今年も新しい VAE 論文が公開されました。URLは以下のとおりです。

この中で損失として、MoL(the discretized mixture of logistics) 、日本語で「離散混合ロジスティック分布」というものが使われています。
この MoL の説明を行います。

離散ロジスティック分布

離散混合ロジスティック分布というぐらいですから、「離散ロジスティック分布(the discretized logistics)」の混合分布モデルになります。
では、離散ロジスティック分布とは何かという話になります。

おそらく最初に提案されたのは以下の論文だと思います。

一般的な画像の1要素は 0 〜 255 の整数値、つまり離散値です。これを利用して確率密度ではなく、確率質量を求める関数を使います。
画像の要素の値が x のときに、x - 0.5 〜 x + 0.5 の範囲の確率質量をを x の確率質量とします。

このときに確率は、累積分布関数を F とすると以下になります。

P(x) = F(x+0.5) - F(x-0.5) 

ただし、最小値の 0 の場合は下限を、最大値の 255 の場合には上限を設定しません。
そのため、以下のようになります。

\begin{eqnarray} 
P(0)   &=& F(0.5) \\
P(255) &=& 1 - F(255-0.5) 
\end{eqnarray}

一般的にニューラルネットで画像処理を行う場合、入力画像の値として 0 〜 255 の整数をそのまま利用するのではなく、0.0 〜 1.0 もしくは -1.0 〜 1.0 の範囲になるように線形変換します。
その場合は、最大値や最小値、値のステップ幅を調整します。

入力が取りうる値の最小値を l、 最大値を g、入力のある値と次の値の差を d とすると、以下の離散ロジスティック分布の確率質量は以下の式になります。

P(x) =
\begin{cases}
F(x + \frac{1}{2} d)     & x = g \\
1 - F(x - \frac{1}{2} d) & x = l \\
F(x + \frac{1}{2} d) - F(x - \frac{1}{2}d) & \text{それ以外}
\end{cases}

また、入力値 x に対して、確率質量を計算する範囲( $x - \frac{1}{2} d$ 〜 $x + \frac{1}{2} d$) を x の瓶(bin)と呼ぶこととします。

堅牢性の向上

Efficient-VDVAE の実装をみると、尤度 P(x) を素直な sigmoid 関数の差で計算しています。しかしながら、その値が 1e-5 以下の場合に、特殊な分岐と値のクリッピングをしています。

        log_probs = torch.where(broadcast_targets == self.min_pix_value, log_cdf_plus, 
                                torch.where(broadcast_targets == self.max_pix_value, log_one_minus_cdf_min, 
                                            torch.where(cdf_delta > 1e-5, 
                                                        torch.log(torch.clamp(cdf_delta, min=1e-12)), 
                                                        log_pdf_mid - np.log(self.num_classes / 2))))  # B, C, M, H, W 

引用元

確認したところ、PixelCNN++ のソースコードのコメントに意図が書いてありました。

now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us)

this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select()

log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta)))

robust version, that still works if probabilities are below 1e-5 (which never happens in our code)
tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs
the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue
if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value

引用元

tf.select() の逆伝播では勾配を選択する代わりに、0との積が使われているため、潜在的に勾配が NaN になる可能性があると記載されています。
だいぶ本筋からは外れますが、私は日常的には pytorch を使っているため、pytorch で同じ問題が発生するかを確認します。

log(0) のような値が無限になるような関数を torch.where を使って値を制限したときに、勾配が NaN になるかを確認します。
確認用のコードは以下の通りです。

import torch 

x = torch.zeros([], requires_grad=True) 
y = torch.log(x) 
y.backward() 

print("[without where()] y = {}, x.grad = {}".format(y.item(), x.grad.item())) 

x = torch.zeros([], requires_grad=True) 
y = torch.where(x > 1e-5, torch.log(x), torch.full_like(x, -20)) 
y.backward() 

print("[with where()] y = {}, x.grad = {}".format(y.item(), x.grad.item())) 

x = torch.zeros([], requires_grad=True) 
y = torch.where(x > 1e-5, torch.log(torch.clamp(x, min=1e-12)), torch.full_like(x, -20)) 
y.backward() 

print("[with where() and clamp()] y = {}, x.grad = {}".format(y.item(), x.grad.item())) 

pytorch のバージョン 1.10.0+cpu で確認しました。
出力結果は以下のとおりです。

$ python nan.py
[without where()] y = -inf, x.grad = inf
[with where()] y = -20.0, x.grad = nan
[with where() and clamp()] y = -20.0, x.grad = 0.0

pytorch でも非選択側が inf の場合、勾配が NaN になるようです。

その上で、尤度が 1e-5 未満の場合は近似値として、x の bin の範囲での確率密度が常に x の確率密度の場合の確率質量を利用します。
確率密度関数を f とすると以下の式になります。

P(x) \sim f(x) * d

上記のコードでは対数尤度の計算をしているため、次のようになっています。

\log{P(x)} \sim \log{f(x)} + \log{d}

対数尤度の式の簡略化

特殊なトリックが必要にならないように、計算式が簡略化できないか考えます。

まず、元の確率質量の計算式を以下に示します。

P(x) = F(x + \frac{1}{2} d) - F(x - \frac{1}{2} d)

ロジスティック分布の累積分布関数は次の通りです。

F(x;\mu, s) = \frac{1}{1 + exp(-(x - \mu)/s)}

入力が x のときの bin の下限を $x_L$ とします。$x_L$ は以下の式で表現できます。

x_L = x - \frac{1}{2}d

入力が x のときの bin 上限を $x_U$ とします。$x_U$ は以下の式で表現できます。

\begin{eqnarray}
x_U &=& x + \frac{1}{2}d \\
    &=& x_L + d/s
\end{eqnarray}

確率質量の式の F(x) にロジスティック分布の累積分布関数を使い、$x_L$ で表した式が以下になります。

P(x) = \frac{1}{1 + exp(-(x_L + d/s))} - \frac{1}{1 + exp(-x_L)} \\

この式を変形していきます。

\begin{eqnarray}
P(x) &=& \frac{1}{1 + exp(-(x_L + d/s))} - \frac{1}{1 + exp(-x_L)} \\
     &=& \frac{ (1 + exp(-x_L)) - (1 + exp(-(x_L + d/s))) } {(1 + exp(-(x_L + d/s)))(1 + exp(-x_L))} \\
     &=& \frac{ exp(-x_L) - exp(-(x_L + d/s)) } {(1 + exp(-(x_L + d/s)))(1 + exp(-x_L))} \\
     &=& \frac{ exp(-x_L) - exp(-x_L)exp(-d/s) } {(1 + exp(-(x_L + d/s)))(1 + exp(-x_L))} \\
     &=& \frac{ exp(-x_L) (1 - exp(-d/s)) } {(1 + exp(-(x_L + d/s)))(1 + exp(-x_L))} \\
     &=& \frac{ 1 - exp(-d/s) } {(1 + exp(-(x_L + d/s)))\frac{1 + exp(-x_L)}{exp(-x_L)}} \\
     &=& \frac{ 1 - exp(-d/s) } {(1 + exp(-(x_L + d/s)))(1 + exp(x_L))}
\end{eqnarray}

損失関数としては対数尤度を利用するため、対数を取ります。

\begin{eqnarray}
\log{P(x)} &=& \log\left\{\frac{ 1 - exp(-d/s) } {(1 + exp(-(x_L + d/s)))(1 + exp(x_L))} \right\} \\
           &=& \log(1 - exp(-d/s)) - \log(1 + exp(-(x_L + d/s))) - \log(1 + exp(x_L))
\end{eqnarray}

$\log(1 + exp(x))$ は softplus 関数であるため、以下のように表現できます。

\log{P(x)} = \log(1 - exp(-d/s)) - softplus(-x_U) - softplus(x_L)

ついでに、x が下限 l の場合の式も変形します。

\begin{eqnarray}
P(x=l) &=& 1 - F(x_U) \\
       &=& 1 - \frac{1}{1 + exp(-x_U)} \\
       &=& \frac{(1 + exp(-x_U)) - 1}{1 + exp(-x_U)} \\
       &=& \frac{exp(-x_U)}{1 + exp(-x_U)} \\
       &=& \frac{1}{1 + exp(x_U)}
\end{eqnarray}

対数をとると

\begin{eqnarray}
\log{P(x=l)} &=& \log \left\{ \frac{1}{1 + exp(x_U)} \right\} \\
             &=& - \log(1 + exp(x_U)) \\
             &=& - softplus(x_U)
\end{eqnarray}

となります。

最終的にまとめると、対数尤度は以下に式になります。

\log{P(x)} =
\begin{cases}
-softplus(-x_L)                          & x = g \\
-softplus(x_U)                           & x = l \\
\log(1 - exp(-d/s)) - softplus(-x_U) - softplus(x_L) & \text{それ以外}
\end{cases}

$\log(1 - exp(-d/s))$ の計算が問題なくできれば問題はなさそうです。

log1mexp

Tensorflow には $\log(1 - exp(-|x|))$ を安定的に計算する関数があります。

pytorch の場合は自分で実装する必要があります。
幸い issue には登録されています。

この issue に安定的な実装方法についての R のドキュメントへのリンクがあります。
このドキュメントによると以下のように実装すれば安定します。

log1mexp(x) =
\begin{cases}
\log(-expm1(-x)) & 0 < x \leq \log{2} \\
log1p(-exp(-x)) & \log{2} < x
\end{cases}

一応、逆伝播も実装するため、微分式を求めてみます。

\begin{eqnarray}
f(x)  &=& \log(1 - exp(-x)) \\
f'(x) &=& \frac{1}{1 - exp(-x)} exp(-x) \\
      &=& \frac{1}{exp(x) - 1}
\end{eqnarray}

コードは以下のとおりです。

import math 
import torch 
 
LOG2 = math.log(2.) 
 
class Log1mexp(torch.autograd.Function): 
 
    @staticmethod 
    def forward(ctx, x): 
        result = torch.where(LOG2 <= x, 
            torch.log(-torch.expm1(-x)), 
            torch.log1p(-torch.exp(-x)) 
        ) 
 
        ctx.save_for_backward(x) 
        return result 
 
    @staticmethod 
    def backward(ctx, grad_output): 
        x, = ctx.saved_tensors 
        return grad_output / torch.expm1(x) 
 
exp1mexp = Log1mexp.apply 

Optimizer について

最後に利用の際の注意点を述べます。

最初に紹介した Efficient-VDVAE の論文では、損失項に MoL を利用する場合、勾配の二次モーメントが1よりもはるかに大きくなる可能性があるため、Adam ではなく Adamax の利用を提案しています。したがって、MoLを利用する場合は Adamax を利用しましょう。

個人的には Adam と比べた場合 Adamax に特段デメリットが見当たらないため、MoL に限らず常に Adamax を利用した方がよいと考えています。

以上

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?