LoginSignup
10
8

More than 5 years have passed since last update.

混合ガウス分布のEMアルゴリズムによる推定

Last updated at Posted at 2017-05-12
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
plt.style.use("ggplot")

K = 2     # number of mixture components
M = 1     # used as a factor in the sigma update of the EM loop (presumably the data dimensionality)
N = 1000  # total number of samples

# Ground-truth parameters of the two Gaussians the data is drawn from
mu1 = 1
sigma1 = 0.2
N1 = int(N*0.3)
mu2 = 5
sigma2 = 1
# Take the remainder rather than int(N*0.7): truncating both counts
# independently can make N1 + N2 < N (e.g. N = 11 gives 3 + 7), and the
# rest of the script assumes len(x) == N.
N2 = N - N1

# Samples of the first component followed by samples of the second.
x = np.concatenate([np.random.normal(mu1, sigma1, N1), np.random.normal(mu2, sigma2, N2)])
# Histogram of the samples, normalized so that its area is 1.
# `density=True` replaces the `normed=True` keyword, which was deprecated in
# Matplotlib 2.1 and later removed (passing it now raises an error).
plt.hist(x, bins=100, density=True)
plt.show()

ASjFHlVqtP15AAAAAElFTkSuQmCC.png

# Initial values
# Mixing weights: K uniform random draws rescaled so they sum to 1.
w = np.random.uniform(low=0, high=1, size=K)
w = w / w.sum()
# Starting guesses for the component means and standard deviations.
mu = [3, 10]
sigma = [1, 1]
# One row per sample, one column per component; filled in by the E-step.
ita = np.zeros((N, K))

# EM algorithm
training_iter = 50
# Grid on which the fitted component curves are drawn (same for every epoch).
x_ = np.linspace(0, 8, 200)
for epoch in range(training_iter):
    # E-step: weighted density of every sample under each component,
    # then normalized across components so each row of `ita` sums to 1.
    for k in range(K):
        ita[:, k] = w[k] * stats.norm(mu[k], sigma[k]).pdf(x)
    ita = ita / ita.sum(1)[:, np.newaxis]

    # M-step: re-estimate weights, means and standard deviations from `ita`.
    resp_sum = ita.sum(0)  # per-component sum of the normalized weights
    w = resp_sum / N
    mu = ita.T.dot(x) / resp_sum
    for k in range(K):
        # Divide by (resp_sum[k] * M). The previous form `/ sum * M` multiplied
        # by M because `/` and `*` evaluate left to right; the two agree only
        # for M == 1. NOTE(review): assumes M is the data dimensionality.
        sigma[k] = np.sqrt(((x - mu[k]) ** 2 * ita[:, k]).sum(0) / (resp_sum[k] * M))

    # Plot the weighted density of every component over the data histogram.
    for k in range(K):
        plt.plot(x_, w[k] * stats.norm(mu[k], sigma[k]).pdf(x_))
    # `density=True` replaces the removed `normed=True` keyword.
    plt.hist(x, bins=100, density=True)
    title = "epoch: {}".format(epoch+1)
    plt.title(title)
    # plt.savefig("data/" + title + ".png")
    plt.show()

タイトルなし.gif

10
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
8