Help us understand the problem. What is going on with this article?

(機械学習)混合ガウス分布におけるEMアルゴリズムを実装を交えて丁寧に理解しようとした

はじめに

 現在、ベイズ推論について勉強しています。今回、混合ガウス分布におけるEMアルゴリズムについて私の理解を記しておこうと思います。

 勉強を進めていく中で、簡単な例で良いので図示したり具体化しながら考えることで式や言葉の理解が非常に早く進むなぁ、と感じています。
 なので、なるべく実装を交えて理解しやすい記事にできたら良いなと思います。

 今回こちらの記事を多大に参考にさせて頂きました。概念から式変形、実装に至るまで非常にわかりやすくまとめられております。

EMアルゴリズム徹底解説
https://qiita.com/kenmatsu4/items/59ea3e5dfa3d4c161efb

EMアルゴリズムとは

 EMアルゴリズム(Expectation Maximazation algorithm)とは、隠れ変数を含むモデルの学習・最適化に使われるアルゴリズムのことです。

混合係数

 用語の解説としてまずは混合係数について理解を深めます。

 下記のような2次元の観測$x$が得られた場合の$x$の確率モデル$p(x)$を考えます。この時、2つのクラスタA,Bから生成されているようにみえるため、これを反映させたモデルを考えます。

image.png

 ガウス分布によって決まるとしたとき、下記のように表すことができます。

\begin{align}
p(x) &= \pi_A\mathcal N(x|\mu_A, \Sigma_A) +\pi_B\mathcal N(x|\mu_B, \Sigma_B)\\

\end{align}

但し、

  • $x$:$D$次元のデータ
  • $\mathcal{D}={x_1,\cdots, x_N}$ : $N$個の観測点(データ集合)
  • $μ$:$D$次元の平均ベクトル
  • $Σ$:$D×D$の共分散行列
  • $\mathcal N(x|μ,Σ)$:$D$次元のガウス分布

とします。一般化すると下記です。

\begin{align}
p(x) &= \sum_{k=1}^{K} \pi_k\mathcal N(x|\mu_k, \Sigma_k) \hspace{1cm}(式1)

\end{align}

この$π_k$を混合係数(mixing coefficient)と呼び、下記を満たします。

\sum_{k=1}^{K} π_k =1\\

0 \leqq π_k \leqq 1

但し、$K$:クラスタ数とします。つまり、混合係数とは各クラスターにおける重み(=どのクラスターが一番存在確率が高いか)を表す数値になります。

負担率

 次に、負担率という言葉を考えます。
 $π_k$=$p(k)$を$k$番目のクラスタを選択する事前確率とし
$\mathcal N(x|\mu_k, \Sigma_k)=p(x|k)$を$k$が与えられた時の$x$の条件付き確率とすると、$x$の周辺密度は

p(x) = \sum_{k=1}^{K} p(k)p(x|k)\hspace{1cm}(式2)\\

と表すことができる。これは先ほどの式$1$と等価です。
さて、この時の$p(k|x)$を負担率と呼びます。この負担率は$γ_k(x)$とも表し、ベイズの定理を用いて、

\begin{align}
γ_k(x) &\equiv p(k|x)\\
&=\frac {p(k)p(x|k)}{\sum_lp(l)p(x|l)}\\
&=\frac {π_k\mathcal N(x|\mu_k, \Sigma_k)}{\sum_lπ_l\mathcal N(x|\mu_l, \Sigma_l)}

\end{align}

と表すことができます。この負担率は何を意味するのでしょうか。
実装しながら確認したいと思います。

実装して確認してみる

EM.ipynb
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from scipy import stats as st
# ======================================
# Parameters
K = 4
n = 300
xx = np.linspace(-4, 10, n)

mu = [-1.2, 0.4, 2, 4]
sigma = [1.1, 0.7, 1.5, 2.0]
pi = [0.2, 0.3, 0.3, 0.2]

# Density function
pdfs = np.zeros((n, K))
for k in range(K):
    pdfs[:, k] = pi[k]*st.norm.pdf(xx, loc=mu[k], scale=sigma[k])

# =======================================
# Visualization
plt.figure(figsize=(14, 6))
for k in range(K):
    plt.plot(xx, pdfs[:, k])
plt.title("pdfs")
plt.show()

plt.figure(figsize=(14, 6))
plt.stackplot(xx, pdfs[:, 0], pdfs[:, 1], pdfs[:, 2], pdfs[:, 3])
plt.title("stacked")
plt.show()

image.png

負担率とは、上記図で見ると分かるようにある地点$x$が与えられた時の混合ガウス分布の密度関数の中値で、それぞれクラスターに対して$k$が占める割合となります。

ガウス分布について寄り道(共分散行列の3種類)

 さて、ガウス分布を学ぶ上で共分散行列$Σ$について、$Σ=σ^2I$と置かれて計算が進められている場合があります。
 これは何を意味しているのかを考えるために、共分散行列の3種類についてまとめておきます。

一般対称共分散行列

 $D$次元のガウス分布を考えます。この時、共分散行列$Σ$は下記のように表すことができます。

\Sigma = \begin{pmatrix}
σ_{1}^2 & σ_{12} & ・・・& σ_{1D}^2 \\
σ_{12} & σ_{2}^2\\
・\\
・\\
・\\
σ_{1D}& σ_{22} & ・・・& σ_{D}^2\\
\end{pmatrix}\\

このような共分散行列を一般対称型と呼びます。この共分散行列は$D×(D+1)/2$個の自由パラメータがあります(上記の行列の変数を数えて求めます)。

共分散行列が対角

次に、共分散行列が対角となる場合を考えます。

\Sigma =diag(σ_i^2)=\begin{pmatrix}
σ_{1}^2 & 0 & ・・・& 0 \\
0 & σ_{2}^2\\
・\\
・\\
・\\
0& 0 & ・・・& σ_{D}^2\\
\end{pmatrix}\\

この場合、自由パラメータは次元数と等しく$D$個となります。

共分散行列が単位行列に比例(等方)

最後に、共分散行列が単位行列に比例する場合を考えます。これは、等方共分散行列と呼びます。

\Sigma =σ^2\bf I= σ^2\begin{pmatrix}
1 & 0 & ・・・& 0 \\
0 & 1\\
・\\
・\\
・\\
0& 0 & ・・・& 1\\
\end{pmatrix}\\

このような場合は自由パラメータが$σ$の一つだけになります。さて、これらの3つの場合について確率密度を下記に示します。

image.png

自由パラメータの数が少なくなることで計算が簡単になるため、計算速度が上がります。一方で、確率密度の表現力が下がることが分かります。一般対称の場合では、対角や等方的な形も表すこともできます。
 
 計算を速く行うことと表現力を担保することを両立するために、潜在変数や非観測変数を導入することで解決しようとする考え方があります。
 この潜在変数及び複数のガウス分布(=混合ガウス分布)によって表現力を高めることがよく行われています。

混合ガウス分布へのEMアルゴリズム適用

 さて、本論に戻ります。今回題材としているEMアルゴリズムは下記の3.に当たります。

image.png

潜在変数$\bf z^T$を行ベクトルとする$N×K$行列$\bf Z$としたとき、対数尤度関数は以下のように表すことができます。

\begin{align}
ln \hspace{1mm} p(\bf X| π,μ,Σ) &=\sum_{n=1}^{N}ln\Bigl\{ \sum_{k=1}^{K} \pi_k\mathcal N(x|\mu_k, \Sigma_k) \Bigr\}

\end{align}

さて、この対数尤度関数を最大化させることで未知のデータ$x$に対して高確率で予測することが可能になります。つまり、最尤関数を求めるということになります。

尤度という考え方の詳しい内容はこちら非常に詳しく参考にさせて頂きました。

【統計学】尤度って何?をグラフィカルに説明してみる。
https://qiita.com/kenmatsu4/items/b28d1b3b3d291d0cc698

 ただ、今回の関数最尤化を解析的に行うことは非常に難しいです($log-Σ$は解くことが非常に困難)。そこで、EMアルゴリズムと呼ばれる方法で求めることを考えます。

 混合ガウス分布におけるEMアルゴリズムは下記の通りになります。

image.png

混合ガウスモデルが与えられているとき(あるいは自身で設定したとき)、各要素の平均、分散及び混合係数のパラメータを尤度関数が最大となるように調整することが目的です。

ニューラルネットワークにおいて重みパラメータを更新することに類似していると考えました。重みパラメータは、損失関数の勾配を用いて更新して損失関数が最小となる重みパラメータを求めます。
 このとき、損失関数自体の勾配(=微分)は数学的に求めることは計算量が膨大となってしまいます。従って、逆誤差伝播法と呼ばれるアルゴリズムで勾配を求めます。

 これと同じように、EMアルゴリズムの場合は$π、μ、Σ$をそれぞれ計算して更新させます。そして、収束判定させて基準を満たしていれば最尤関数とするのです。

パラメータ更新

 混合ガウス分布におけるEMアルゴリズムでは、負担率$γ(z_{nk})$、混合係数$π_k$、平均$μ_k$、共分散行列$Σ_k$を更新する必要があります。
 この計算は微分して0とおいてゴリゴリ解いていく必要があります。解き方に関してはこちらの記事が非常に詳しいので、見て頂けると幸甚です。

EMアルゴリズム徹底解説
https://qiita.com/kenmatsu4/items/59ea3e5dfa3d4c161efb

結果として、各パラメータは下記のように表すことができます。

γ(z_{nk}) =\frac {π_k\mathcal N(x_n|\mu_k, \Sigma_k)}{\sum_{l=1}^{K}π_l\mathcal N(x|\mu_l, \Sigma_l)}\\

π_k = \frac {N_k}{N}\\
μ_k = \frac {1}{N_k}\sum_{n=1}^{N}γ(z_{nk})\bf{ x_n}\\
\Sigma_k = \frac{1}{N_k}\sum_{n=1}^{N}γ(z_{nk})(\bf x_n -μ_k)(\bf x_n -μ_k)^T\\

見て頂いて分かるかと思いますが、$π、μ、Σ$は全て$γ(z_{nk})$に依存しています。
 また、単一のガウス分布における最尤関数を求めようとするときはこの$γ(z_{nk})$が1となった時の値です。
 すると、それぞれ単純に平均値や共分散を求めているだけなので、理解しやすいのではないでしょうか。

実装してみる

プログラムはこちらに格納しております。
https://github.com/Fumio-eisan/EM_20200509/upload

gmm_anim.gif

点は平均値、実線は確率密度分布の等高線になっています。徐々に各クラスタごとに最適化されていくのが分かります。

終わりに

 今回、混合ガウス分布におけるEMアルゴリズムをまとめました。数学的な理解の前に視覚的に知ることで理解が進みやすいと思いました。

 式展開及び実装面は私自身弱いので、もっと手を動かして理解を深めます。また、一般的なEMアルゴリズムについて記事を書きたいと思います。

プログラムはこちらに格納しております。
https://github.com/Fumio-eisan/EM_20200509/upload

Fumio-eisan
製造業で技術者として働いています。昨今の機械学習ブームをきっかけに興味を持ちました。学んだことの定着率を上げるために記事としてまとめていきます。
https://fumio-eisan.hatenablog.com/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした