Edited at

混合ガウス分布の変分ベイズ法による推定

More than 1 year has passed since last update.


図示

変分混合ガウス分布.gif

初期クラス数を6に設定しても、最終的に3つのクラスに収束していることがわかる


アルゴリズム


  1. $r_{nk}$を初期化

  2. 三つの統計量を計算


    • $N_k = \sum_{n=1}^N r_{nk}$

    • ${\bar x_k} = \frac{1}{N_k} \sum_{n=1}^N r_{nk}x_n$

    • $S_k = \frac{1}{N_k} \sum_{n=1}^N r_{nk}(x_n - {\bar x_k})(x_n - {\bar x_k})^T$



  3. Mstep: $q(\pi) = Dir(\pi|\alpha), q(\mu_k, \Lambda_k) = N(\mu_k|m_k, (\beta_k \Lambda_k)^{-1})W(\Lambda_k|W_k, \nu_k)$を求める。


    • $\alpha_k = \alpha_0 + N_k$

    • $\beta_k = \beta_0 + N_k$

    • $m_k = \frac{1}{\beta_k}(\beta_0 m_0 + N_k {\bar x_k})$

    • $W_k^{-1} = W_0^{-1} + N_k S_k + \frac{\beta_0 N_k}{\beta_0 + N_k}({\bar x_k} - m_0)({\bar x_k} - m_0)^T$

    • $\nu_k = \nu_0 + N_k$



  4. Estep: $q(Z) = \prod_{n=1}^N \prod_{k=1}^K r_{nk}^{z_{nk}}$を求める


    • $r_{nk} = \frac{\rho_{nk}}{\sum_{j=1}^K \rho_{nj}}$

    • $\ln\rho_{nk} = E[\ln\pi_k] + \frac{1}{2} E[\ln |\Lambda_k|] - \frac{D}{2}\ln(2\pi) - \frac{1}{2}E_{\mu_k, \Lambda_k}[(x_n - \mu_k)^T \Lambda_k (x_n - \mu_k)]$

    • $E_{\mu_k, \Lambda_k}[(x_n - \mu_k)^T \Lambda_k (x_n - \mu_k)] = D\beta_k^{-1} + \nu_k(x_n - m_k)^TW_k(x_n - m_k)$

    • $E[\ln|\Lambda_k|] = \sum_{i=1}^D \psi(\frac{\nu_k + 1 - i}{2}) + D \ln 2 + \ln|W_k|$

    • $E[\ln \pi_k] = \psi(\alpha_k) - \psi({\hat \alpha})$

    • ${\hat \alpha} = \sum_k \alpha_k$



  5. 収束するまで2~4をまわす


実装

%matplotlib inline

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from numpy import linalg as LA
from scipy import stats
from scipy.special import digamma
plt.style.use("ggplot")

# Dimensionality of each observation
D = 2
# Total number of observations
N = 2000

# Ground-truth mixture used to synthesise the data set: three isotropic
# Gaussians that receive 30% / 50% / 20% of the N points respectively.
mu1, mu2, mu3 = [0, 1], [-1, -1], [1, -1]
sigma1, sigma2, sigma3 = 0.2 * np.eye(D), 0.1 * np.eye(D), 0.1 * np.eye(D)
N1, N2, N3 = int(N * 0.3), int(N * 0.5), int(N * 0.2)

# Sample the clusters one after another (cluster 1, then 2, then 3) and
# stack the draws row-wise into a single array.
data = np.concatenate([
    np.random.multivariate_normal(mean, cov, size)
    for mean, cov, size in (
        (mu1, sigma1, N1),
        (mu2, sigma2, N2),
        (mu3, sigma3, N3),
    )
])

# Scatter plot of the raw samples on a fixed square viewport.
plt.figure(figsize=(5, 5))
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
plt.scatter(data[:, 0], data[:, 1], s=10)
plt.show()

(図: 生成したデータの散布図)

# Initial (deliberately too large) number of mixture components
K = 6

# Starting values for the component means and covariances
mu = np.array([
    [0., -0.5],
    [0., 0.5],
    [1., 0.5],
    [-1., -0.5],
    [-1, -1.5],
    [1, 1.5],
])
S = np.tile(0.1 * np.eye(2), (K, 1, 1))

# Hyperparameters of the prior distributions
alpha_0, beta_0 = 1e-3, 1e-3
nu_0 = 1
m_0 = np.zeros((K, D))
W_0 = np.eye(D)

# Buffers that the update loop fills in:
#   W_k      -- one D x D matrix per component
#   E_mu_lam -- one value per (data point, component) pair
W_k = np.zeros((K, D, D))
E_mu_lam = np.zeros((N, K))

def multi_gauss(x, y, mu, sigma):
    """Evaluate a 2-D Gaussian density at the point(s) (x, y).

    Args:
        x: First coordinate; a scalar or an array.
        y: Second coordinate; a scalar or an array broadcastable with ``x``.
        mu: Mean vector of the Gaussian (length 2).
        sigma: Covariance matrix of the Gaussian (2 x 2).

    Returns:
        The density N((x, y) | mu, sigma): a scalar for scalar inputs,
        otherwise an array with the broadcast shape of ``x`` and ``y``.
    """
    # Put the coordinate pair on the last axis, which is where scipy's
    # ``pdf`` expects the components of each point. This keeps scalar calls
    # working and additionally allows whole arrays/grids in a single call.
    points = np.stack(np.broadcast_arrays(x, y), axis=-1)
    return stats.multivariate_normal(mu, sigma).pdf(points)

# Step 1: initialise the responsibilities r_nk.
# With uniform mixing weights pi, r_nk is the normalised density of data
# point n under the initial Gaussian of component k (mean mu[k], cov S[k]).
pi = np.ones(K) / K
g = np.empty((N, K))
for k in range(K):
    # Evaluate the density for all points in one call instead of building a
    # new frozen distribution per data point.
    g[:, k] = pi[k] * stats.multivariate_normal(mu[k], S[k]).pdf(data)
# Normalise every row so that the responsibilities of a point sum to one.
r = g / g.sum(axis=1, keepdims=True)

# Plot the initial state: density contours of every component plus the data
# points coloured by their most responsible component.
X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))
# Grid points with the (x, y) pair on the last axis, as scipy's pdf expects.
grid = np.dstack([X, Y])
cmap_colors = [cm.spring, cm.summer, cm.autumn, cm.winter, cm.Reds_r, cm.Dark2]
colors = ["pink", "green", "orange", "blue", "red", "black"]
plt.figure(figsize=(5, 5))
for k in range(K):
    # Density of component k on the whole grid in a single call.
    Z = stats.multivariate_normal(mu[k], S[k]).pdf(grid)
    plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)

# Build a real list of colour names: on Python 3 ``map`` returns a lazy
# iterator, which matplotlib cannot interpret as a colour sequence.
point_colors = [colors[j] for j in r.argmax(1)]
plt.scatter(data[:, 0], data[:, 1], c=point_colors, alpha=0.3, s=10)
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
init_title = "iter: 0"
plt.title(init_title)
# Uncomment to save this frame to disk:
#plt.savefig("data/" + init_title + ".png")
plt.show()

(図: 初期状態 (iter: 0) におけるクラス割り当てと各ガウス分布の等高線)

# Run 20 sweeps of the variational updates (steps 2-4 of the algorithm) and
# draw one figure per sweep.

# Evaluation grid for the contour plots; it is the same in every iteration,
# so build it once. The (x, y) pair sits on the last axis for scipy's pdf.
X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))
grid = np.dstack([X, Y])

for i in range(20):

    # Step 2: the three statistics N_k, x_bar_k (stored in mu) and S_k.
    N_k = r.sum(0)
    mu = r.T.dot(data) / np.c_[N_k]
    for k in range(K):
        diff = data - mu[k]
        S[k] = (np.c_[r[:, k]] * diff).T.dot(diff) / N_k[k]

    # Step 3 (M step): parameters of q(pi) and q(mu_k, Lambda_k).
    alpha = alpha_0 + N_k
    beta = beta_0 + N_k
    m_k = (beta_0 * m_0 + np.c_[N_k] * mu) / np.c_[beta]
    for k in range(K):
        tmp1 = beta_0 * N_k[k] * np.outer(mu[k] - m_0[k], mu[k] - m_0[k]) / (beta_0 + N_k[k])
        tmp2 = LA.inv(W_0) + N_k[k] * S[k] + tmp1
        W_k[k] = LA.inv(tmp2)
    nu_k = nu_0 + N_k

    # Step 4 (E step): recompute the responsibilities r_nk.
    # E[ln|Lambda_k|] = sum_{i=1..D} psi((nu_k + 1 - i) / 2) + D ln 2 + ln|W_k|.
    # The digamma sum runs over all D terms rather than being written out
    # for D = 2, and ln|W_k| is the log-DETERMINANT (slogdet), not a matrix
    # norm.
    digamma_sum = digamma((np.c_[nu_k] + 1 - np.arange(1, D + 1)) / 2.).sum(1)
    E_ln_lam = digamma_sum + D * np.log(2) + LA.slogdet(W_k)[1]
    E_ln_pi = digamma(alpha) - digamma(alpha.sum())
    for k in range(K):
        diff = data - m_k[k]
        # Row-wise quadratic form (x_n - m_k)^T W_k (x_n - m_k); this avoids
        # materialising the full N x N product only to read its diagonal.
        E_mu_lam[:, k] = D / beta[k] + nu_k[k] * (diff.dot(W_k[k]) * diff).sum(1)
    ln_ro = E_ln_pi + E_ln_lam / 2. - D * np.log(2 * np.pi) / 2. - E_mu_lam / 2.
    # Subtract the row maximum before exponentiating: the normalised result
    # is unchanged, but a row can no longer underflow to all zeros.
    ro = np.exp(ln_ro - ln_ro.max(1, keepdims=True))
    r = ro / np.c_[ro.sum(1)]
    # Floor tiny responsibilities so that no N_k becomes exactly zero in the
    # divisions of the next sweep.
    r[r < 1e-10] = 1e-10

    # Draw the frame for the GIF animation.
    plt.figure(figsize=(5, 5))
    pi = np.exp(E_ln_pi)
    for k in range(K):
        # Only draw components whose weight exp(E[ln pi_k]) is non-negligible.
        if pi[k] > 0.01:
            Z = stats.multivariate_normal(mu[k], S[k]).pdf(grid)
            plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)
    # A real list is required here: on Python 3 ``map`` yields a lazy
    # iterator, which matplotlib cannot interpret as a colour sequence.
    point_colors = [colors[j] for j in r.argmax(1)]
    plt.scatter(data[:, 0], data[:, 1], c=point_colors, s=10, alpha=0.3)
    plt.xlim(-2.1, 2.1)
    plt.ylim(-2.1, 2.1)
    title = "iter: {}".format(i+1)
    plt.title(title)
    # Uncomment to save this frame to disk:
    #plt.savefig("data/" + title + ".png")
    plt.show()