概要
EMアルゴリズムのうち、一次元の混合正規分布に対するパラメータ推定についての解説になります。
EMアルゴリズムの数式を見て挫折しかけている方にとって、何かしらの参考になることを期待しています。
また、プログラムでデータの生成や計算を行う際は、なるべく数式そのものを自分でコーディングするスタイルで実装しています。
記事作成の背景
私は、数学を高校までしか習っておらず、大学の数学は独学。
ですが、とあるきっかけでEMアルゴリズムを学ぶことになりました。
そこで、続・わかりやすいパターン認識と、
以下の記事を読んだのですが、
それでも、一次元の混合正規分布に対してEMアルゴリズムを適用する場合の数式の導出方法がわかりませんでした。上記書籍も記事も非常にわかりやすかったのですが、私には前提知識が足りなかったからか、まだわからない所がかなり残っていました。
そんな状態で、その他色々な記事の記述を組み合わせたり、延々と式を書き直したりしてなんとか実装にこぎ着けたので記事化しました。
この記事の対象者
- 「続・わかりやすいパターン認識 教師なし学習入門」の第5章までは理解できた
- 「続・わかりやすいパターン認識 教師なし学習入門」の第6章第2節「log-sumからsum-logへ」まで、数式の展開は理解できた(完全には意味が分かっていなくてもなんとかなる)
上記の状態 1 の方が、第9章「混合分布のパラメータ推定」の第4節、第5節の混合正規分布のパラメータ推定(ただし一次元)を、理解した上でPythonで実装できるようにすることを目標として、この記事は書かれています。
前提
前記のような対象者を想定しているので、対数尤度関数や事後分布といった用語については説明しません。
その他にも、以下の項目については既に理解しているか、他の記事で補う前提で書かれています。
- 最尤推定法
- 対数尤度関数
- 事前分布 / 事後分布
- 偏微分
- ラグランジュの未定乗数法
- イェンゼンの不等式
- 正規分布の確率密度関数
- numpyやmatplotlibの使い方
また、変数の添え字や用語などは 「続・わかりやすいパターン認識 教師なし学習入門」 に合わせているので、同書籍はお手元にある前提になっています。
解説
使用するpythonモジュール
import numpy as np
import matplotlib.pyplot as plt
一次元正規分布の確率密度関数
まず、正規分布の確率密度関数を確認しておきます。
$$
\mathcal{N}(\mu, \sigma^2) = \frac{1}{\sqrt{2\pi} \sigma} e^{- \frac{(x - \mu)^2}{2\sigma^2}} ・・・①
$$
混合正規分布とは、この正規分布を複数種類組み合わせたものですね。
なので、平均値 $\mu$ と分散 $\sigma^2$ は複数種類、たとえば3種類あると仮定すると①の式は以下のようになります。
また、観測値 $x$ はn個得られたとして、k番目の観測値についての式としておきます。
\mathcal{N}(\mu_i, \sigma^2_i) = \frac{1}{\sqrt{2\pi} \sigma_i} e^{- \frac{(x_k - \mu_i)^2}{2\sigma^2_i}} (i \in \{1, 2, 3\})・・・②
この②式は後で使います。
ちなみに、p.170や171などでは以下のような記述がありますが、いずれも②と同じ意味合いになります。
\begin{align}
p(x|\omega_i) &= \frac{1}{\sqrt{2\pi} \sigma_i} e^{- \frac{(x - \mu_i)^2}{2\sigma^2_i}} (9\cdot2) \\
\\
p(x_k|\omega_i; \boldsymbol{\theta}) &
\end{align}
ここで、正規分布を1つ生成する関数を作成します。
def normal_dist(x: np.ndarray, mu: float, sigma: float) -> np.ndarray:
"""
xから、正規分布に従うyを出力する
"""
denominator = (np.sqrt(2 * np.pi) * sigma) + 1e-8 # NOTE: to avoid division by zero
f_x: np.ndarray = 1 / denominator * np.exp(-(x - mu) ** 2 / (2 * sigma ** 2))
return f_x
試しに適当な正規分布を生成してみます。
x = np.linspace(-10, 10, 1001)
mu = 3.0
sigma = 2.0
norm = normal_dist(x, mu, sigma)
# 描画
plt.plot(x, norm)
plt.xlabel("x")
plt.ylabel("probability")
plt.show()
よさげ( *˙︶˙*)وグッ!
混合正規分布(一次元)
今回は、上記の正規分布3つを混合します。
どういう「混合」かというと、「平均値 $\mu$ と分散 $\sigma^2$ がそれぞれ異なる正規分布を、それぞれ重みづけて足し合わせた分布を作成する」という意味で「混合」という言葉が使われています。
3種類の正規分布の混合割合(重み)を、それぞれ $\pi_1, \pi_2, \pi_3$ とおきます。今回はたとえば、 $\pi_1 = 0.2, \pi_2 = 0.3, \pi_3 = 0.5$ として考えてみます。(混合した後も確率分布として扱える状態を維持したいので、合計が1にならないとダメです。)
すると、混合正規分布の確率密度関数は以下のようになります。
$$
p(x_k; \boldsymbol{\theta}) = p(x_k; \mu_i, \sigma_i^2) = \sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2) ・・・③
$$
これはp.171の(9·4)式と同じものを指しています。
p(x_k; \boldsymbol{\theta}) = \sum^{c}_{i=1} \pi_i p(x_k|\omega_i, \boldsymbol{\theta}_i) (k = 1, 2, ..., n) (9\cdot4)
※補足: (9·4)式について
この(9·4)式は導出過程が省略されているように感じたので、補足しておく。
p.171の5行目に記載されている以下の記述を活用するのでメモしておく。
$\pi_i$ は、各クラスの事前確率 $P(\omega_i)$ であり、c種の確率密度関数の混合比を表す
更に、p.175のベイズの定理の式(9·27)を活用することにより、(9·4) 式を導出できる。
\begin{align}
まずベイズの定&理に従って式を変形する。 \\
p(\omega_i|x_k; \boldsymbol{\theta}) &= \frac{p(x_k|\omega_i; \boldsymbol{\theta}_i)p(\omega_i)}{p(x_k; \boldsymbol{\theta})} ・・・A\\
\\
ここで、なぜ分&子の \theta にだけ添え字 i がついているのか気になるが、 \\
これはおそらく&、\omega_i という事象の発生が既に確認されているので、 \\
「\theta は \theta_i しかあり&えない」と考えてよいからだと思われる。 \\
そのまま次のよ&うに変形する。 \\
\\
&= \frac{p(x_k|\omega_i; \boldsymbol{\theta}_i)p(\omega_i)}{\sum^{c}_{i=1} p(x_k|\omega_i, \boldsymbol{\theta}_i)p(\omega_i)} \\
\\
ここまでがのベ&イズの定理。 \\
この式に、先程&メモした P(\omega_i) = \pi_i を代入すると、 \\
\\
&= \frac{p(x_k|\omega_i; \boldsymbol{\theta})\pi_i}{\sum^{c}_{i=1} p(x_k|\omega_i, \boldsymbol{\theta}_i)\pi_i} ・・・B\\
\\
ここで、A式の&分母とB式の分母は等号で結ぶことができるので、 \\
(9\cdot4) 式を導出&できる。 \\
p(x_k; \boldsymbol{\theta}) &= \sum^{c}_{i=1} \pi_i p(x_k|\omega_i, \boldsymbol{\theta}_i) (9\cdot4)
\end{align}
*補足終わり*
混合正規分布の描画
元に戻って、この③式のデータを生成しましょう。
先程の正規分布を生成する関数を活用します。
# 前提条件
# データは、x = -10 から x = 10 までの範囲で 500個 生成する
x = np.linspace(-10, 10, 500)
mu = [3.0, -2.0, 1.5]
sigma = [2.0, 0.5, 1.0]
pi = [0.2, 0.3, 0.5]
# 混合正規分布のデータを生成
norms = [
normal_dist(x, _mu, _sigma)
for _mu, _sigma in zip(mu, sigma)
]
mixed_norm = norms[0] * pi[0] + norms[1] * pi[1] + norms[2] * pi[2]
# 描画
plt.plot(x, mixed_norm)
plt.xlabel("x")
plt.ylabel("probability")
plt.show()
この関数はEMアルゴリズムの実装には必要ないんですけど、分布の形状のイメージを掴むために実装しておきました。
混合正規分布から観測値の取得
これはほぼコピーさせていただきました。2
def sample_from_mixtured_norm(
mu_true: np.ndarray, sigma_true: np.ndarray, N: int, pi_true: np.ndarray
) -> np.ndarray:
"""
Sample data from the mixtured normal distribution at random
Parameters
------
mu_true: Expectations
Type: np.ndarray
Shape: (K,) (K: Number of observed values)
sigma_true: Variances
Type: np.ndarray
Shape: (K,) (K: Number of kinds of Norm dist)
N: Number of observed values
Type: int
pi_true: rate of
Type: np.ndarray
Shape: (K,) (K: Number of kinds of Norm dist)
Return
------
Shape: (N,) (N: Number of observed values)
"""
K = len(pi_true)
assert K == len(mu_true)
assert K == len(sigma_true)
org_data = None
for i in range(K):
print("check: ", i, mu_true[i], sigma_true[i], np.linalg.det(sigma_true[i]))
size = int(N * pi_true[i])
# 正規分布から標本をサンプリング
sampled_data = np.random.multivariate_normal(mean=mu_true[i], cov=sigma_true[i], size=size)
if org_data is None:
org_data = np.c_[sampled_data, np.ones(size) * i]
else:
org_data = np.r_[org_data, np.c_[sampled_data, np.ones(size) * i]]
data = org_data[:, 0:-1].copy()
return data
この関数は、以下のようにしてデータを生成できます。
seed = 77
np.random.seed(seed)
N = 500
pi = np.array([0.6, 0.4])
mu = np.array([[3], [-1],])
sigma = np.array([[[1]], [[1]],])
sapmled_data = sample_from_mixtured_norm(mu, sigma, N, pi).reshape(-1)
plt.hist(sapmled_data, bins=30)
-1と3に頂点がある感じがしますね。
混合正規分布の対数尤度関数
EMアルゴリズムは、何らかの関数を最大化もしくは最小化するパラメータ $\boldsymbol{\theta}$ の推定のために用いられるアルゴリズムです。
今回は、正規分布の混合割合 $\pi_i$ 、各正規分布の平均値 $\mu_i$ と分散 $\sigma_i^2$ を推定したいので、最大化する対象は、混合正規分布の対数尤度関数になります。
p.173「〔1〕最適なパラメータ」の記載の通り、観測結果を $\textbf{x} = \{x_1, x_2, ..., x_k, ..., x_n\}$ とすると、混合正規分布の対数尤度関数の式は次のように導出します。
\begin{align}
p(\textbf {x}; \boldsymbol {\theta}) &= \prod^n_{k=1} p(x_k; \boldsymbol {\theta}) \\
\\
上記の尤度関数&より、対数尤度関数は\\
\\
\log {p}(\textbf {x}; \boldsymbol {\theta}) &= \log \prod^n_{k=1} p(x_k; \boldsymbol {\theta}) \\
&= \sum^{n}_{k=1} \log p(x_k;\boldsymbol{\theta})
\end{align}
ここに前掲の式③を代入すると
\begin{align}
\log {p}(\textbf {x}; \boldsymbol {\theta}) &= \sum^{n}_{k=1} \log p(x_k;\boldsymbol{\theta}) \\
&= \sum^{n}_{k=1} \log {\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} ・・・④\\
\end{align}
となります。
これは、いわゆる log-sum と呼ばれる形らしく、p.101に書かれているように、一般的に最尤推定値を解析的に求めることはできません。
ここで、「続・わかりやすいパターン認識」p.101などで強く主張されている、EMアルゴリズムらしい対処法を用いることになります。具体的には、イェンゼンの不等式(p.297参照)を利用して、$\log{ \sum f(x)}$ の数式を $\sum {\log f(x)}$ の形へと変換し、解析的に対処できるようにします。
イェンゼンの不等式は、$f(x)$が下に凸の場合は、
f(p_1 x_1 + p_2 x_2 + ... + p_n x_n) \leq p_1 f(x_1) + p_2 f(x_2) + ... + p_n f(x_n)
のようになるが、$f(x)$が下に凸の場合は、不等号の向きが逆になることを考慮しつつ適用すると、
\begin{align}
\log {p}(\textbf {x}; \boldsymbol {\theta}) &= \sum^{n}_{k=1} \log {\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} ・・・④ \\
\\
ここで、\pi_i = p_i & , f(\cdot) = \log (\cdot) と考えてイェンゼンの不等式を適用
\\
& \geq \sum^{n}_{k=1} \sum^{c(=3)}_{i=1} \pi_i \log \mathcal{N}(\mu_i, \sigma_i^2) ・・・⑤
\end{align}
上記のように、$\log$の中から $\sum$を取り出すことができました。
更に、この⑤に対して②を代入すると、対数尤度関数をプログラムで実装できるようになります。
※②再掲
\mathcal{N}(\mu_i, \sigma^2_i) = \frac{1}{\sqrt{2\pi} \sigma_i} e^{- \frac{(x_k - \mu_i)^2}{2\sigma^2_i}} (i \in \{1, 2, 3\})・・・②
さっそく代入します。
\begin{align}
\log {p}(\textbf {x}; \boldsymbol {\theta}) & \geq \sum^{n}_{k=1} \sum^{c(=3)}_{i=1} \pi_i \log \mathcal{N}(\mu_i, \sigma_i^2) ・・・⑤ \\
&= \sum^{n}_{k=1} \sum^{c(=3)}_{i=1} \pi_i \log{\frac{1}{\sqrt{2\pi} \sigma_i} e^{- \frac{(x_k - \mu_i)^2}{2\sigma^2_i}}} ・・・⑥\\
\end{align}
これで、対数尤度関数を手持ちの値で表現できるようになりました。
具体的には、$x_k, \pi_i, \mu_i, \sigma^2_i$ と、定数の $e, \pi$ で算出が可能です。
この⑥を実装します。 ...と、せっかく教科書に従って⑥を計算したんですけど、これを使うと期待した結果にならなかったので、④式に従ってプログラムを実装しました。
この際、先程混合正規分布からサンプリングした観測値が格納されているsapmled_data
に対して対数尤度の計算を行います。
def calc_likelihood(
x: np.ndarray, mu_hat: np.ndarray, sigma2_hat: np.ndarray, pi_hat: np.ndarray
) -> np.ndarray:
"""
Parameters
------
x: observed values
Type: np.ndarray
Shape: 1-dimension (Number of observed values)
mu_hat: Expectations
Type: np.ndarray
Shape: 1-dimension (Number of kinds of Norm dist)
sigma_hat:
Type: Variances
Shape: 1-dimension (Number of kinds of Norm dist)
"""
norm_pdf = normal_dist(x.reshape(-1, 1), mu_hat, sigma2_hat)
likelihood = pi_hat * norm_pdf
return likelihood
試しに呼び出してみるとこんな結果になりました。
# 各パラメータの推定値の初期値(仮)
mu_hat = np.array([-2, -3])
sigma_hat = np.array([1, 1])
pi_hat = np.array([0.6, 0.4])
# 対数尤度の計算 ④式に従って
likelihood: np.ndarray = calc_likelihood(sapmled_data, mu_hat, sigma_hat, pi_hat)
sum_likelihood = np.sum(likelihood, axis=1)
log_sum_likelihood = np.log(sum_likelihood)
sum_log_sum_likelihood = np.sum(log_sum_likelihood)
> -5238.29942012823
合ってるかわからないけど良さそう!(´^ω^`)
対数尤度の最大化
次は、これまで導出してきた対数尤度関数を最大化するパラメータ($\mu_i, \sigma^2_i, \pi_i$)を求めたい。
最尤推定の方法に従えば、$対数尤度関数の微分後の式=0$とおけば各パラメータの最尤推定値が求められそうだが、$\pi_i$ については制約条件があるので注意。
- $\mu_i, \sigma^2_i$ については、それぞれのパラメータで対数尤度関数⑦を偏微分した後に $偏微分後の式= 0$ とすることで、極値をとる点(最尤推定値)を求める。
- $\pi_i$ については制約条件があるので、ラグランジュの未定乗数法に従って極値をとる点(最尤推定値)を求める。
ラグランジュの未定乗数法を適用
制約条件として、p.174に記載の(9·18)式を活用して式Fを組み立てる。
\sum^c_{i = 1} \pi_i = 1 ・・・(9\cdot18)
\begin{align}
F &= 最大化したい対数尤度関数 - \lambda (制約条件の式) = 0 \\
F &= \log {p}(\textbf {x}; \boldsymbol {\theta}) - \lambda \left(\sum^c_{i = 1} \pi_i - 1 \right) \\
\\
④、⑤、②&より \\
\\
&= \sum^{n}_{k=1} \log {\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} - \lambda \left(\sum^c_{i = 1} \pi_i - 1 \right) ・・・⑦ \\
&\geq \sum^{n}_{k=1} \sum^{c(=3)}_{i=1} \pi_i \log \mathcal{N}(\mu_i, \sigma_i^2) - \lambda \left(\sum^c_{i = 1} \pi_i - 1 \right) = 0\\
&= \sum^{n}_{k=1} \sum^{c(=3)}_{i=1} \pi_i \log \frac{1}{\sqrt{2\pi} \sigma_i} e^{- \frac{(x_k - \mu_i)^2}{2\sigma^2_i}} - \lambda \left(\sum^c_{i = 1} \pi_i - 1 \right) = 0 ・・・⑧ \\
\end{align}
μの最尤推定
前出の⑦の式Fを $\mu_i$ で偏微分します。
(あれ?⑧いらんかった...)
\begin{align}
\frac{\partial}{\partial \mu_i} F &= \frac{\partial}{\partial \mu_i} \log {p}(\textbf {x}; \boldsymbol {\theta}) - \frac{\partial}{\partial \mu_i}\lambda \left(\sum^c_{i = 1} \pi_i - 1 \right)\\
&= \frac{\partial}{\partial \mu_i} \left[ \sum^n_{k=1} \log {\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \right] - 0 \\
&= \sum^n_{k=1} \frac{\partial}{\partial \mu_i} \left[\log {\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \right] \\
ここで対数&関数に対する微分の公式を使って、 \\
&= \sum^n_{k=1} \frac {\frac{\partial}{\partial \mu_i} \left[{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \right]}{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \\
\mu_i で偏微分&すると、 \mu_i を持たない項は全て 0 になるため \\
&= \sum^n_{k=1} \frac { { \frac{\partial}{\partial \mu_i} \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right]} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \\
ここで対数&関数に対する微分の公式を応用する\\
具体的には&、 (\log f)' = \frac {f'}{f} の両辺に f をかけて、\\
f' = f(\log& f)' という変換式を導出して適用すると、\\
&= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2) \frac{\partial}{\partial \mu_i} \log \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right]} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \\
\\
&= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac{\partial}{\partial \mu_i} \log \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right] ・・・⑨ \\
\end{align}
と展開できる。先に②を使って $\frac{\partial}{\partial \mu_i} \log \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right]$ だけ計算すると、
\begin{align}
\frac{\partial}{\partial \mu_i} \log \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right] &= \frac{\partial}{\partial \mu_i} \log \left[\pi_i \frac{1}{\sqrt{2\pi} \sigma_i} e^{- \frac{(x_k - \mu_i)^2}{2\sigma^2_i}} \right] \\
&= \frac{\partial}{\partial \mu_i} \left[\log \pi_i - \log{\sqrt{2\pi}} - \log \sigma_i + \log e^{- \frac{(x_k - \mu_i)^2}{2\sigma^2_i}} \right] \\
&= \frac{\partial}{\partial \mu_i} \left[\log \pi_i - \log{\sqrt{2\pi}} - \log \sigma_i - \frac{(x_k - \mu_i)^2}{2\sigma^2_i} \right] \\
&= 0 - 0 - 0 - \frac{\partial}{\partial \mu_i} \frac{(x_k - \mu_i)^2}{2\sigma^2_i} \\
&= - \frac{\partial}{\partial \mu_i} \frac{x^2_k - 2x_k \mu_i + \mu_i^2}{2\sigma^2_i} \\
&= - \frac{- 2x_k + 2 \mu_i}{2\sigma^2_i} = \frac{x_k - \mu_i}{\sigma^2_i} ・・・⑩\\
\end{align}
この⑩を⑨に代入すると
\begin{align}
\frac{\partial}{\partial \mu_i} F &= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac{x_k - \mu_i}{\sigma^2_i}
\end{align}
上記のようになる。
上記のように偏微分した結果が 0 になる点が最尤推定値になるので、
\begin{align}
\frac{\partial}{\partial \mu_i} F &= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac{x_k - \mu_i}{\sigma^2_i} = 0 ・・・⑪ \\
\end{align}
となる。
これが $\mu_i$ の最尤推定値を求める式である。
この続きを導出する前に、負担率
という考え方について触れておく。
負担率とは
色々な記事に目を通してみたのですが、多くの記事で、負担率
という言葉が使われています。
ある記事では、$\gamma(z_{nk})$ と書かれていたり、別の記事では $r_{nk}$ と書かれていたりします。
本記事においては、以下の⑫式が負担率に相当するのですが、本記事では表現の簡素化のために $z$ という変数を持ち出さずに、 $r$ で表します。
添え字については、「続・わかりやすいパターン認識」に合わせて $i$ がクラス(カテゴリ)の種類、 $k$ は観測結果の番号を表しているので、負担率は次のように表現されます。
r_{ki} = \frac { \pi_i \mathcal{N}(\mu_i, \sigma_i^2) }{\sum^{c}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} ・・・⑫
負担率のソースコード
プログラムはこうなりました。
def calc_responsibility(likelihood):
return likelihood.T / np.sum(likelihood, axis=1).T
⑫式に含まれる$\pi_i \mathcal{N}(\mu_i, \sigma_i^2)$は、④式の尤度をsum - log - sumする前の状態なので、ちょうどさっき作った calc_likelihood
関数の出力になります。
axis=1
は、クラス種別の方向です。この方向にsumすると、分母のshapeは(観測値の件数)
になります。
負担率はなぜ必要?
この負担率という考え方を敢えて持ち出さなくても計算は行えるのですが、これを持ち出すことによって数式を簡易に表現できるようになるので、この負担率という言葉と式が用いられているように思われます。
実際、この $r_{ki}$ は、 $\sigma^2$ や $\pi_i$ の最尤推定値の式の導出結果にも含まれているので、負担率という考え方を利用した方が、この後の数式もプログラムも省略しやすくなります。
負担率が表しているもの
負担率が何を表しているかというと、ある$x_k$が観測される確率密度のうち、$i$ カテゴリの分布による確率密度がどの程度その $x_k$ を発生させ得るか、を表しています。
という下手な説明ではわかりにくいので(´^ω^`)、これについては「EMアルゴリズム徹底解説」という記事の3−3. 負担率という項目を紹介しておきます。図による説明が非常にわかりやすいです。
負担率を使ってμの最尤推定値の式を導出
話を戻して、先程の負担率の式⑫を⑪に代入してμの最尤推定値を導出します。
\begin{align}
\frac{\partial}{\partial \mu_i} F &= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac{x_k - \mu_i}{\sigma^2_i} = 0 \\
&\Leftrightarrow \sum^n_{k=1} r_{ki} \cdot \frac{x_k - \mu_i}{\sigma^2_i} = 0 \\
&\Leftrightarrow \sum^n_{k=1} r_{ki} \cdot x_k = \sum^n_{k=1} r_{ki} \cdot \mu_i \\
&\Leftrightarrow \mu_i = \frac {\sum^n_{k=1} r_{ki} \cdot x_k}{\sum^n_{k=1} r_{ki}} \\
\end{align}
ここで ${\sum^n_{k=1} r_{ki}}$ は、先ほどの負担率の考え方に基づくと、全観測結果の内、どれだけ カテゴリ$i$によって観測結果が生じたかを表しているといえるので、
\begin{align}
\mu_i &= \frac {\sum^n_{k=1} r_{ki} \cdot x_k}{\sum^n_{k=1} r_{ki}} \\
&= \frac {1}{N_i}\sum^n_{k=1} r_{ki} \cdot x_k ・・・⑬ \\
\end{align}
と表せます。
$N_i$ は、全観測回数 $N$ のうち、カテゴリ $i$ の正規分布によって発生したであろう観測結果の数(割合ではなく、「数」なのでinteger型です)を表しています。
$N_i$の計算をしてから、
N_i = np.sum(r_i_k, axis=1)
N_i
> array([465.4043453, 34.5956547])
$\mu$の計算式も実装します。
tmp_mu_hat = np.sum(r_i_k * x, axis=1) / N_i
tmp_mu_hat
> array([-0.40820814, 5.49149433])
σ^2の最尤推定値
途中までは $\mu$ の最尤推定値導出と一緒です。
\begin{align}
\frac{\partial}{\partial \sigma_i^2} F &= \frac{\partial}{\partial \sigma_i^2} \log {p}(\textbf {x}; \boldsymbol {\theta}) -\frac{\partial}{\partial \sigma_i^2}\lambda \left(\sum^c_{i = 1} \pi_i - 1 \right) \\
&= \frac{\partial}{\partial \sigma_i^2} \left[ \sum^n_{k=1} \log {\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \right] - 0 \\
&= \sum^n_{k=1} \frac{\partial}{\partial \sigma_i^2} \left[\log {\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \right] \\
ここで対数&関数に対する微分の公式を使って、 \\
&= \sum^n_{k=1} \frac {\frac{\partial}{\partial \sigma_i^2} \left[{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \right]}{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \\
\sigma_i^2 で偏微分&すると、 \sigma_i^2 を持たない項は全て 0 になるため \\
&= \sum^n_{k=1} \frac { { \frac{\partial}{\partial \sigma_i^2} \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right]} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \\
ここで対数&関数に対する微分の公式を応用する\\
具体的には&、 (\log f)' = \frac {f'}{f} の両辺に f をかけて、\\
f' = f(\log& f)' という変換式を導出して適用すると、\\
&= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2) \frac{\partial}{\partial \sigma_i^2} \log \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right]} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \\
\\
&= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac{\partial}{\partial \sigma_i^2} \log \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right] ・・・⑭ \\
\end{align}
なんと、ここまでは $\mu_i$ の最尤推定値を導出した際と全く同じ式変形です。
ここからは異なるのですが、②を使って $\frac{\partial}{\partial \sigma_i^2} \log \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right]$ を計算すると、
\begin{align}
\frac{\partial}{\partial \sigma_i^2} \log \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right] &= \frac{\partial}{\partial \sigma_i^2} \log \left[\pi_i \frac{1}{\sqrt{2\pi} \sigma_i} e^{- \frac{(x_k - \mu_i)^2}{2\sigma^2_i}} \right] \\
&= \frac{\partial}{\partial \sigma_i^2} \left[\log \pi_i - \log{\sqrt{2\pi}} - \log \sigma_i + \log e^{- \frac{(x_k - \mu_i)^2}{2\sigma^2_i}} \right] \\
&= \frac{\partial}{\partial \sigma_i^2} \left[\log \pi_i - \log{\sqrt{2\pi}} - \log \sigma_i - \frac{(x_k - \mu_i)^2}{2\sigma^2_i} \right] \\
ちょっと無理やりですが&、\sigma_i^2 で微分したいので \sigma_i^2 = a とおいて式変形します \\
&= \frac{\partial}{\partial a} \left[\log \pi_i - \log{\sqrt{2\pi}} - \log \sqrt{a} - \frac{(x_k - \mu_i)^2}{2 a} \right]\\
&= 0 - 0 - \frac{1}{\sqrt {a} \cdot 2 \sqrt {a}} - \frac{(x_k - \mu_i)^2}{2} \frac{\partial}{\partial a} a^{-1} \\
&= - \frac{1}{2 a} - \frac{(x_k - \mu_i)^2}{2} (-1)a^{-2} \\
a = \sigma_i^2 を代入して元の&記号に戻します\\
&= - \frac{1}{2 \sigma^2_i} + \frac{(x_k - \mu_i)^2}{2} \sigma^{2 \cdot (-2)}_i\\
&= - \frac{1}{2 \sigma^2_i} + \frac{(x_k - \mu_i)^2}{2 \sigma_i^{4}}
・・・⑮ \\
\end{align}
では、⑭式に対して⑮と、負担率 $r_{ki}$ ⑫を代入し、$=0$とおいて最尤推定値を導出しましょう。
\begin{align}
\frac{\partial}{\partial \sigma_i^2} F &= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac{\partial}{\partial \sigma_i^2} \log \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right] ・・・⑭ \\
&= \sum^n_{k=1} r_{ki} \left[- \frac{1}{2 \sigma^2_i} + \frac{(x_k - \mu_i)^2}{2 \sigma_i^{4}} \right] = 0 \\
&\Leftrightarrow \sum^n_{k=1} r_{ki}\frac{1}{2 \sigma^2_i} = \sum^n_{k=1} r_{ki} \frac{(x_k - \mu_i)^2}{2 \sigma_i^{4}} \\
&\Leftrightarrow \sigma^2_i = \frac{\sum^n_{k=1} r_{ki} (x_k - \mu_i)^2}{\sum^n_{k=1} r_{ki}} \\
ここで、& \mu_iの最尤推定値導出時と同様に、 \sum^n_{k=1} r_{ki} = N_i とすると \\
&\Leftrightarrow \sigma^2_i = \frac{1}{N_i} \sum^n_{k=1} r_{ki} (x_k - \mu_i)^2 ・・・⑯ \\
\end{align}
$N_i$を利用して$\hat \sigma^2$の計算も実装します。
x_mu_diff =(x.reshape(-1, 1) - mu_hat).T
tmp_sigma_hat = np.sum(r_i_k * (x_mu_diff ** 2), axis=1) / N_i
tmp_sigma_hat
> array([35.3980779 , 81.28158419])
πの最尤推定値
前出の⑦の式Fを $\pi_i$ で偏微分します。
\begin{align}
\frac{\partial}{\partial \pi_i} F &= \frac{\partial}{\partial \pi_i} \log {p}(\textbf {x}; \boldsymbol {\theta}) -\frac{\partial}{\partial \pi_i}\lambda \left(\sum^c_{i = 1} \pi_i - 1 \right) \\
&= \frac{\partial}{\partial \pi_i} \left[ \sum^n_{k=1} \log {\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \right] - \lambda \\
&= \sum^n_{k=1} \frac{\partial}{\partial \pi_i} \left[\log {\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \right] - \lambda \\
ここで対数&関数に対する微分の公式を使って、 \\
&= \sum^n_{k=1} \frac {\frac{\partial}{\partial \pi_i} \left[{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \right]}{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} - \lambda \\
\sigma_i^2 で偏微分&すると、 \sigma_i^2 を持たない項は全て 0 になるため \\
&= \sum^n_{k=1} \frac { { \frac{\partial}{\partial \pi_i} \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right]} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} - \lambda \\
ここで対数&関数に対する微分の公式を応用する\\
具体的には&、 (\log f)' = \frac {f'}{f} の両辺に f をかけて、\\
f' = f(\log& f)' という変換式を導出して適用すると、\\
&= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2) \frac{\partial}{\partial \pi_i} \log \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right]} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} - \lambda \\
\\
&= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac{\partial}{\partial \pi_i} \log \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right] - \lambda \\
&= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac {\frac{\partial}{\partial \pi_i} \left[\pi_i \mathcal{N}(\mu_i, \sigma_i^2)\right]}{\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} - \lambda \\
&= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac {\mathcal{N}(\mu_i, \sigma_i^2)}{\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} - \lambda \\
&= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac {1}{\pi_i} - \lambda \\
\end{align}
今は、$\frac{\partial}{\partial \pi_i} F = 0$ となる点を求めるラグランジュの未定乗数法の途中なので、
\begin{align}
\frac{\partial}{\partial \pi_i} F &= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac {1}{\pi_i} - \lambda = 0 ・・・⑰ \\
&\Leftrightarrow \sum^n_{k=1} \frac {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)}{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} = \lambda{\pi_i} \\
&\Leftrightarrow \sum^c_{i=1}\sum^n_{k=1} \frac {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)}{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} = \sum^c_{i=1}\lambda \pi_i \\
&\Leftrightarrow \sum^n_{k=1} \frac { \sum^c_{i=1}{\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} = \lambda \sum^c_{i=1}\pi_i \\
&\Leftrightarrow \sum^n_{k=1} 1 = \lambda \cdot 1 \\
&\Leftrightarrow N = \lambda ・・・⑱ \\
(※N, n は&観測データの総数)\\
ここで、&⑰に⑱と負担率⑫を代入すると \\
\frac{\partial}{\partial \pi_i} F &= \sum^n_{k=1} \frac { {\pi_i \mathcal{N}(\mu_i, \sigma_i^2)} }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} \cdot \frac {1}{\pi_i} - \lambda = 0 ・・・⑰ \\
&\Leftrightarrow \sum^n_{k=1} r_{ki} \cdot \frac {1}{\pi_i} = N \\
&\Leftrightarrow \pi_i = \frac {\sum^n_{k=1} r_{ki}}{N} ・・・⑲ \\
\end{align}
最後に$\pi_i$の実装です。
N = len(x)
tmp_pi_hat = N_i / N
tmp_pi_hat
> array([0.93080869, 0.06919131])
$\pi_i$は各正規分布の混合割合なので、合計1になっていないとおかしいですね。
この結果は良さげ。
これで、$\mu_i, \sigma^2_i, \pi_i$ の最尤推定値を計算できる状態になりました。
やっと必要な計算式が出そろった、ということです。お疲れさまでした(o_ _)o))
EMアルゴリズムを組み立てる
EMアルゴリズムでパラメータ推定を行うために必要な数式を並べておきます。
\begin{align}
r_{ki} &= \frac { \pi_i \mathcal{N}(\mu_i, \sigma_i^2) }{\sum^{c(=3)}_{i=1} \pi_i \mathcal{N}(\mu_i, \sigma_i^2)} ・・・⑫ \\
\mu_i &= \frac {1}{N_i}\sum^n_{k=1} r_{ki} \cdot x_k ・・・⑬ \\
\sigma^2_i &= \frac{1}{N_i} \sum^n_{k=1} r_{ki} (x_k - \mu_i)^2 ・・・⑯ \\
\pi_i &= \frac {\sum^n_{k=1} r_{ki}}{N} = \frac {N_i}{N} ・・・⑲ \\
\end{align}
こうして並べてよく見てみると分かるのですが、負担率 $r_{ki}$ の算出には $\pi_i, \mu_i, \sigma_i$ が必要ですし、パラメータ $\pi_i, \mu_i, \sigma_i$ 算出には 負担率 $r_{ki}$ が必要になっています。
実は、これらを交互に計算し続けていくのがEMアルゴリズムになります。
EMアルゴリズムの計算処理の概要 3 をまとめるとこんな感じになるようです。
- パラメータ初期値を設定
-
Eステップ
負担率を計算 -
Mステップ
- 各パラメータを最大化するための最尤推定を実行
最尤推定の処理中で負担率を使う - 各パラメータを新しい値で更新
- 各パラメータを最大化するための最尤推定を実行
-
終了判定
- 対数尤度を計算
- 対数尤度にほとんど変化が見られなくなったら処理終了
- 処理が続く場合は[2. Eステップ]に戻る
これら全体の実装はGithubにアップロードしておきました。以下を参照。
めちゃ時間かかった。
参考資料
EMアルゴリズム
理論
理論面での計算式の導出や、プログラム内の変数名 / 関数名を付ける際などに以下の記事を参照しました。
実装
実装で一番勉強になったのは以下の記事。
見たことないメソッドなどが使われていて最初は戸惑ったけれど、ソースコードが短くまとめられていて、読みやすかったです。
イェンゼンの不等式
p.297以外にも、以下のように様々な記事で解説されているので、受け入れやすい、理解しやすいものを探してもいいと思います。
TeX
いつもお世話になっております(o_ _)o))
あとがき1(動機など)
なぜこの記事を書いたかというと、「前提」に書いた項目を全部満たした状態でEMアルゴリズムの勉強を始めたのに、それでも理解するのに何十時間もかかったからです。
次に学ぶ人がこの時間をもっと短縮出来たら、科学技術の進歩も加速する...!(大袈裟)と期待しています。
(これくらい細かい式展開の補足を加えた記事が世の中に溢れればいいなぁ)(-人-)ナム
次はこの調子で、二次元の混合正規分布にもEMアルゴリズムを適用できるよう勉強していきます!これについてはいくつかの記事がネット上に挙がっているので、そちらに解説お任せします( ˘ω˘)スヤァ
あとがき2(大変だったこと)
EMアルゴリズムは、「EステップとMステップの繰り返しだよ」と色んな本や記事に書かれているけど、Eステップに関する説明の仕方が人によって違っていて、めちゃ混乱した。
以下のような感じ。
- Eステップ
少なくとも、以下の3つの記述が見受けられた。
...統一しませんか?みんな仲悪いんですか?- 負担率を計算する
- Q関数を計算する
- Zの事後分布を計算する
- Mステップ
求めた負担率(or Q関数 or Zの事後分布)を使って、パラメータを最尤推定。 - 収束判定
対数尤度の変化量が基準値を下回ったら処理終了。
ここは結構統一されていたように感じる。
でも、EステップでもMステップでも登場しない「対数尤度」って言葉がいきなり出てくるので困った。
よくよく一連の式を観察すると、尤度から負担率を求めることができるってわかるんですけど、そのことに気が付くまでかなりの時間がかかった。
あとは、最初はランダムな観測値を用意してなくて、混合正規分布の確率密度関数をそのまま観測値として計算させていたり、色々な勘違いがあって相当な時間を費やした。
だからこそ期待通りの動作が確認できた今、満足感は非常に高い(*˘︶˘*)
-
「上記の状態」というのは、実はこの記事を書き始めた日の前日の私です(´^ω^`)ぜんぜんわかってなかった ↩
-
EMアルゴリズム徹底解説 の記事を作成された方のソースコードを参考にしました ↩
-
この辺りの記述は、「続・わかりやすいパターン認識」や各種記事などでそれぞれに異なっていて、混乱しました。詳しくは「あとがき2」を参照されたい_(-ω-`_)⌒)_ ↩