58
49

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

混合ガウス分布とlogsumexp

Last updated at Posted at 2015-06-27

#概要
混合ガウス分布は複数のガウス分布の重み付き加算によるモデルで,多峰性の分布を表すことが出来るものです.また,このモデルのパラメータの最適化にはEMアルゴリズムを用いることが出来ます.このEMアルゴリズムのE-stepで潜在変数の期待値を計算しますが,その際にナイーブに計算を行うとoverflow, underflowの問題が起こることがあります.そのとき,logsumexpという有名な数値計算方法を用いることでこの問題を回避することができます.

#混合ガウス分布の概要とlogsumexp
混合ガウス分布のEMアルゴリズムによるパラメータ最適化に関して,logsumexpに関係する部分だけ簡単に示します.より詳しくは『パターン認識と機械学習』などを参照してください.(一応,SlideShareに輪講したときの資料を上げておきます PRML 9.0-9.2)

\begin{aligned}
p(x|\theta) &= \sum_k \pi_k N(x_n|\mu_k,\Sigma_k)\\
\pi_k&: 混合比
\end{aligned}

混合ガウス分布の観測変数$x$の分布は上の式で表されます.
EMアルゴリズムを適用するために,潜在変数$z$を導入し,$x,z$の同時分布を次のように表します.

\begin{aligned}
p(x,z) &= \prod_k \pi_k^{z_k} N(x_n|\mu_k,\Sigma_k)^{z_k}\\
z&: 1ofK
\end{aligned}

これらを用いて潜在変数zの期待値をベイズの定理より求めると,以下のようになります.

\begin{aligned}
\gamma(z_{nk})
&= E[z_{nk}] \\ 
&=\frac{\pi_k N(x_n|\mu_k,\Sigma_k)}{\sum_{k^\prime}\pi_{k^\prime} N(x_n|\mu_{k^\prime},\Sigma_{k^\prime})}
\end{aligned}

混合ガウスのEMアルゴリズムによる最適化でのE-stepでは,この潜在変数の期待値を計算します.分母を見るとガウス分布のsum演算があるので,log scaleで演算を行う場合でもガウス分布の指数演算でunderflowが起こりうることが分かります.そこで何らかの工夫をする必要がありますが,その一つの方法としてlogsumexpがあります.

#logsumexp

\log(\sum^N_{i=1} \exp(x_i))

という計算をしたいときに,この計算の結果自体はoverflow, underflowしない場合でも,個別の$\exp(x_i)$はoverflow, underflowする可能性があります.そこで,以下のように式変形を行います.

\begin{aligned}
\log(\sum^N_{i=1} \exp(x_i)) 
&= \log\{\exp(x_{max})\sum^N_{i=1} \exp(x_i - x_{max})\} \\
& = \log\{\sum^N_{i=1} \exp(x_i - x_{max})\}  + x_{max}
\end{aligned}

このように計算をすることにより指数演算の引数が$[x_{min} - x_{max},0]$の範囲に収まるので,overflow, underflowが起こる可能性が大幅に下がります.

このlogsumexpはPythonの機械学習ライブラリscikitlearnにも実装されており,簡単に利用することが出来ます(Scikit Learn Utilities for Developers).また,scikit learnでのlogsumexpの実装はsklearn/utils/extmath.pyにあり,以下のようになっています.axisに対応している部分以外は上に示した数式と全く同じになっていることが分かります.


def logsumexp(arr, axis=0):
    """Computes the sum of arr assuming arr is in the log domain.
    Returns log(sum(exp(arr))) while minimizing the possibility of
    over/underflow.
    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.utils.extmath import logsumexp
    >>> a = np.arange(10)
    >>> np.log(np.sum(np.exp(a)))
    9.4586297444267107
    >>> logsumexp(a)
    9.4586297444267107
    """
    arr = np.rollaxis(arr, axis)
    # Use the max to normalize, as with the log this is what accumulates
    # the less errors
    vmax = arr.max(axis=0)
    out = np.log(np.sum(np.exp(arr - vmax), axis=0))
    out += vmax
    return out

#混合ガウスモデルの実装
以上を踏まえて混合ガウスモデルの実装したので置いておきます(https://github.com/seataK/machine-learning/blob/master/GaussianMixture/GaussianMixture.py).

import numpy as np
import random
import pylab as plt
from sklearn.utils import extmath
from sklearn.cluster import KMeans
import sys


na = np.newaxis


class DataFormatter:
    def __init__(self, X):
        self.mean = np.mean(X, axis=0)
        self.std = np.std(X, axis=0)

    def standarize(self, X):
        return (X - self.mean[na, :]) / self.std[na, :]


def log_gaussian(X, mu, cov):
    d = X.shape[1]
    det_sig = np.linalg.det(cov)
    A = 1.0 / (2*np.pi)**(d/2.0) * 1.0 / det_sig**(0.5)
    x_mu = X - mu[na, :]
    inv_cov = np.linalg.inv(cov)
    ex = - 0.5 * np.sum(x_mu[:, :, na] * inv_cov[na, :, :] *
                        x_mu[:, na, :], axis=(1, 2))
    return np.log(A) + ex


class GMM:
    def __init__(self,
                 K=2,
                 max_iter=300,
                 diag=False):
        self.K = K
        self.data_form = None
        self.pi = None
        self.mean = None
        self.cov = None
        self.max_iter = max_iter
        self.diag = diag

    def fit(self, _X):
        self.data_form = DataFormatter(_X)
        X = self.data_form.standarize(_X)
        N = X.shape[0]
        D = X.shape[1]
        K = self.K

        # init parameters using K-means
        kmeans = KMeans(n_clusters=self.K)

        kmeans.fit(X)

        self.mean = kmeans.cluster_centers_

        self.cov = np.array([[[1 if i == j else 0
                             for i in range(D)]
                             for j in range(D)]
                             for k in range(K)])

        self.pi = np.ones(K) / K

        # Optimization
        for _ in range(self.max_iter):
            # E-step

            gam_nk = self._gam(X)

            # M-step
            Nk = np.sum(gam_nk, axis=0)

            self.pi = Nk / N

            self.mean = np.sum(gam_nk[:, :, na] * X[:, na, :],
                               axis=0) / Nk[:, na]

            x_mu_nkd = X[:, na, :] - self.mean[na, :, :]

            self.cov = np.sum(gam_nk[:, :, na, na] *
                              x_mu_nkd[:, :, :, na] *
                              x_mu_nkd[:, :, na, :],
                              axis=0) / Nk[:, na, na]

            if(self.diag):
                for k in range(K):
                    var = np.diag(self.cov[k])
                    self.cov[k] = np.array([[var[i] if i == j else 0
                                             for i in range(D)]
                                            for j in range(D)])

    def _gam(self, X):
        log_gs_nk = np.array([log_gaussian(X, self.mean[i], self.cov[i])
                              for i in range(self.K)]).T

        log_pi_gs_nk = np.log(self.pi)[na, :] + log_gs_nk

        log_gam_nk = log_pi_gs_nk[:, :] - extmath.logsumexp(log_pi_gs_nk, axis=1)[:, na]

        return np.exp(log_gam_nk)

    def predict(self, _X):
        X = self.data_form.standarize(_X)

        gam_nk = self._gam(X)

        return np.argmax(gam_nk, axis=1)

#その他
logsumexpではlog,expの演算が多く出てくるので計算速度が遅くなるようです(logsumexp は人類の黒歴史).速度がボトルネックになる場合は他の方法を使ったほうが良いのかもしれません.

58
49
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
58
49

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?