
Estimating a Gaussian Mixture Model with Variational Bayes

Posted at 2017-05-13

Visualization

変分混合ガウス分布.gif

Even though the initial number of classes is set to 6, the mixture can be seen to converge to 3 classes in the end.
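
The animation above was presumably assembled from the per-iteration PNG frames written by the commented-out plt.savefig calls in the implementation below. As a sketch of how such a gif could be stitched together, assuming those calls are enabled and the imageio package is installed:

# Assumes frames "data/iter: 0.png" ... "data/iter: 20.png" exist
# (uncomment the plt.savefig lines in the code below to produce them).
import imageio

frames = [imageio.imread("data/iter: {}.png".format(i)) for i in range(21)]
imageio.mimsave("変分混合ガウス分布.gif", frames)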

Algorithm

  1. Initialize $r_{nk}$
  2. Compute the three statistics
    • $N_k = \sum_{n=1}^N r_{nk}$
    • ${\bar x_k} = \frac{1}{N_k} \sum_{n=1}^N r_{nk}x_n$
    • $S_k = \frac{1}{N_k} \sum_{n=1}^N r_{nk}(x_n - {\bar x_k})(x_n - {\bar x_k})^T$
  3. M step: compute $q(\pi) = Dir(\pi|\alpha)$ and $q(\mu_k, \Lambda_k) = N(\mu_k|m_k, (\beta_k \Lambda_k)^{-1})W(\Lambda_k|W_k, \nu_k)$.
    • $\alpha_k = \alpha_0 + N_k$
    • $\beta_k = \beta_0 + N_k$
    • $m_k = \frac{1}{\beta_k}(\beta_0 m_0 + N_k {\bar x_k})$
    • $W_k^{-1} = W_0^{-1} + N_k S_k + \frac{\beta_0 N_k}{\beta_0 + N_k}({\bar x_k} - m_0)({\bar x_k} - m_0)^T$
    • $\nu_k = \nu_0 + N_k$
  4. E step: compute $q(Z) = \prod_{n=1}^N \prod_{k=1}^K r_{nk}^{z_{nk}}$
    • $r_{nk} = \frac{\rho_{nk}}{\sum_{j=1}^K \rho_{nj}}$
    • $\ln\rho_{nk} = E[\ln\pi_k] + \frac{1}{2} E[\ln |\Lambda_k|] - \frac{D}{2}\ln(2\pi) - \frac{1}{2}E_{\mu_k, \Lambda_k}[(x_n - \mu_k)^T \Lambda_k (x_n - \mu_k)]$
    • $E_{\mu_k, \Lambda_k}[(x_n - \mu_k)^T \Lambda_k (x_n - \mu_k)] = D\beta_k^{-1} + \nu_k(x_n - m_k)^TW_k(x_n - m_k)$
    • $E[\ln|\Lambda_k|] = \sum_{i=1}^D \psi(\frac{\nu_k + 1 - i}{2}) + D \ln 2 + \ln|W_k|$
    • $E[\ln \pi_k] = \psi(\alpha_k) - \psi({\hat \alpha})$
    • ${\hat \alpha} = \sum_k \alpha_k$
  5. Repeat steps 2-4 until convergence (a minimal convergence-check sketch follows this list).
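
Step 5 calls for iterating until the posterior stops changing, while the implementation below simply runs a fixed 20 iterations. A minimal convergence-check sketch (not part of the original code; the tolerance is arbitrary) is to stop once the responsibilities barely move between iterations; monitoring the variational lower bound is the more standard criterion.

import numpy as np

def has_converged(r_new, r_old, tol=1e-6):
    # True when the largest change in any responsibility r_nk is below tol
    return np.max(np.abs(r_new - r_old)) < tol

In the loop below, keeping r_old = r.copy() before the E step and breaking once has_converged(r, r_old) holds would implement this step.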

Implementation

%matplotlib inline
import numpy as np
from numpy import linalg as LA  # used below for LA.inv and LA.det
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy import stats
from scipy.special import digamma
plt.style.use("ggplot")

# dimensionality
D = 2
# number of data points
N = 2000


# true parameters used to generate the data
mu1 = [0, 1]
sigma1 = 0.2 * np.eye(D)
N1 = int(N*0.3)
mu2 = [-1, -1]
sigma2 = 0.1 * np.eye(D)
N2 = int(N*0.5)
mu3 = [1, -1]
sigma3 = 0.1 * np.eye(D)
N3 = int(N*0.2)

plt.figure(figsize=(5, 5))
data = np.concatenate([np.random.multivariate_normal(mu1, sigma1, N1), 
                    np.random.multivariate_normal(mu2, sigma2, N2),
                    np.random.multivariate_normal(mu3, sigma3, N3)
                   ])
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
plt.scatter(data[:, 0], data[:, 1], s=10)
plt.show()

(scatter plot of the generated data)
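
Because the data are drawn with np.random, every run produces a slightly different scatter plot. Fixing the seed before generating the data is an optional addition (not in the original) that makes the figures reproducible:

np.random.seed(0)  # optional: fix the RNG so the generated data and figures are reproducible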

# initial number of classes
K = 6
# initial values for the component means and covariances
mu = np.array([[0., -0.5],[0., 0.5], [1., 0.5], [-1., -0.5], [-1, -1.5], [1, 1.5]])
S = np.array([0.1 * np.eye(2) for k in range(K)])
 
# hyperparameters of the prior distributions
# (the small alpha_0 is what allows unneeded components to be switched off)
alpha_0 = 1e-3
beta_0 = 1e-3
m_0 = np.zeros((K, D))
nu_0 = 1
W_0 = np.eye(D)
 
# arrays for the variational parameters computed in the loop
W_k = np.zeros((K, D, D))
E_mu_lam = np.zeros((N, K))
 
def multi_gauss(x, y, mu, sigma):
    return stats.multivariate_normal(mu, sigma).pdf(np.array([x, y]))
 
# 1: initialize the responsibilities r_nk
r = np.ones([N, K]) / K
pi = np.ones(K) / K
g = np.zeros((N, K))
for k in range(K):
    g[:, k] = np.vectorize(lambda x, y: pi[k] * multi_gauss(x, y, mu[k], S[k]))(data[:, 0], data[:, 1])
for k in range(K):
    r[:, k] = g[:, k] / g.sum(1)

# plot the initial cluster assignment
X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))
cmap_colors = [cm.spring, cm.summer, cm.autumn, cm.winter, cm.Reds_r, cm.Dark2]
colors = ["pink", "green", "orange", "blue", "red", "black"]
plt.figure(figsize=(5, 5))
for k in range(K):
    Z = np.vectorize(lambda x, y: multi_gauss(x, y, mu[k], S[k]))(X, Y)
    plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)
 
 
plt.scatter(data[:, 0], data[:, 1], c=[colors[x] for x in r.argmax(1)], alpha=0.3, s=10)
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
init_title = "iter: 0"
plt.title(init_title)
#plt.savefig("data/" + init_title + ".png")
plt.show()

(initial responsibilities and component contours, iter: 0)

for i in range(20):
    # 2: compute the three statistics N_k, xbar_k, S_k
    N_k = r.sum(0)
    mu = r.T.dot(data) / np.c_[N_k]
    for k in range(K):
        S[k] = (np.c_[r[:, k]] * (data - mu[k])).T.dot(data - mu[k]) / N_k[k]
 
    # 3: M step: update alpha_k, beta_k, m_k, W_k, nu_k
    alpha = alpha_0 + N_k
    beta = beta_0 + N_k
    m_k = (beta_0 * m_0 + np.c_[N_k] * mu) / np.c_[beta]
    for k in range(K):
        tmp1 = beta_0 * N_k[k] * np.outer(mu[k] - m_0[k], mu[k] - m_0[k]) / (beta_0 + N_k[k])
        tmp2 = LA.inv(W_0) + N_k[k] * S[k] + tmp1
        W_k[k] = LA.inv(tmp2)
    nu_k = nu_0 + N_k
 
    # 4: E step
    # E[ln|Lambda_k|] = sum_{i=1}^{D} psi((nu_k + 1 - i) / 2) + D ln 2 + ln|W_k|
    E_ln_lam = (sum(digamma((nu_k + 1 - i) / 2) for i in range(1, D + 1))
                + D * np.log(2) + np.log([LA.det(w) for w in W_k]))
    E_ln_pi = digamma(alpha) - digamma(alpha.sum())
    for k in range(K):
        E_mu_lam[:, k] = D / beta[k] + nu_k[k] * np.diag((data - m_k[k]).dot(W_k[k]).dot((data - m_k[k]).T))
    rho = np.exp(E_ln_pi + E_ln_lam / 2. - D * np.log(2 * np.pi) / 2. - E_mu_lam / 2.)
    r = rho / np.c_[rho.sum(1)]
    r[r < 1e-10] = 1e-10
 
    # create a frame for the gif
    plt.figure(figsize=(5, 5))
    X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))
    pi = np.exp(E_ln_pi)
    for k in range(K):
        Z = np.vectorize(lambda x, y: multi_gauss(x, y, mu[k], S[k]))(X, Y)
        if pi[k] > 0.01:  # only draw components that keep appreciable weight
            plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)
    plt.scatter(data[:, 0], data[:, 1], c=[colors[x] for x in r.argmax(1)], s=10, alpha=0.3)
    plt.xlim(-2.1, 2.1)
    plt.ylim(-2.1, 2.1)
    title = "iter: {}".format(i+1)
    plt.title(title)
    #plt.savefig("data/" + title + ".png")
    plt.show()
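
After the loop, the "six initial classes collapse to three" behaviour seen in the animation can be checked numerically from the Dirichlet posterior: under $q(\pi) = Dir(\pi|\alpha)$ the expected mixing weight is $E[\pi_k] = \alpha_k / {\hat \alpha}$, and components supported by no data keep a weight near zero. A minimal inspection sketch (run after the loop; the 0.01 threshold is the same heuristic used for the contour plots):

# Expected mixing weights E[pi_k] = alpha_k / sum_j alpha_j under q(pi).
pi_hat = alpha / alpha.sum()
survivors = np.where(pi_hat > 0.01)[0]   # components that kept appreciable weight
print("effective number of components:", len(survivors))
print("E[pi_k] of surviving components:", pi_hat[survivors])
print("posterior means m_k of surviving components:")
print(m_k[survivors])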
