1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyMC3を用いた混合ガウスモデルのパラメータ推定

Posted at

はじめに

この記事は、note 「パソコンが3時間掛けてノイズを作った」の詳細を解説した記事です。
ざっくりな形です。細かく書くとただ既存の本の内容を書き写すだけになるので、もっと細かくな方は参考をご覧いただければと思います。

動作環境

MCMC(マルコフ連鎖モンテカルロ法)

  • モンテカルロ法は乱数を用いた手法の総称
  • 乱数の発生にマルコフ連鎖を用いる

これがMCMC

マルコフ連鎖

取りうる状態の全てを$S$、時点を$t = 1,2,...$とし、過去に起こった状態の履歴が与えられた下での条件付確率として表現すると、$t$において状態が$i(i \in S)$である人が次の$t + 1$において状態$j$となる確率は

P(x^{(t+1)} = j | x^{(1)} = i_1, x^{(2)} = i_2, ..., x^{(t)} = i)

と表せる。この条件付確率が

P(x^{(t+1)} = j | x^{(1)} = i_1, x^{(2)} = i_2, ..., x^{(t)} = i) = P(x^{(t+1)} = j | x^{(t)} = i)

という性質をすべての$t$について満たすならば、その確率過程はマルコフ連鎖と呼びます。
未来の状態$x^{(t+1)}$は現在の状態$x^{(t)}$のみに依存しており、過去の状態は未来の状態へ影響のないものとなっています。

$S = \lbrace 1, 2 \rbrace$とし、
現在が1の時、未来で1になる確率を0.2、2になる確率を0.8とします。
現在が2の時、未来で1になる確率を0.5、2になる確率を0.5とします。

$x^{(t)}$から$x^{(t+1)}$に推移する際の条件付確率$P(x^{(t+1)} = j | x^{(t)} = i) = P_{ij}$とすると、
上記の例では$P_{12} = 0.8, P_{21} = 0.5$となります。
$P_{ij}$を行列$\boldsymbol{P}$の要素とすると

\boldsymbol{P} = P_{ij} = \begin{pmatrix}
0.2 & 0.8 \\
0.5 & 0.5 \\
\end{pmatrix}

となります。

そして、$S$の各要素から1回目に選択する確率をベクトル$\boldsymbol{\pi}^{(1)} = (\pi_1^{(1)}, \pi_2^{(1)})$とします。
2回目は$\boldsymbol{\pi}^{(2)} = (\pi_1^{(2)}, \pi_2^{(2)})$となります。
2回目の$\boldsymbol{\pi}^{(2)}$は$\boldsymbol{\pi^{(1)}} \boldsymbol{P} = \boldsymbol{\pi^{(2)}}$で求めることができます。
これは回数にかかわらず$\boldsymbol{\pi^{(t)}} \boldsymbol{P} = \boldsymbol{\pi^{(t+1)}}$となります。

$x^{(t)}$から$x^{(t+1)}$,$x^{(t+1)}$から$x^{(t+2)}...x^{(t+k)}$から$x^{(t+k+1)}$のように推移を長い時間繰り返した過程を全体としてみると、共通の分布を形成することがあります。
つまりこれ以上回数を重ねても$\boldsymbol{\pi}$の中身が変わらないということです。
これはつまり、$\boldsymbol{\pi} \boldsymbol{P} = \boldsymbol{\pi}$が成立していることになります。
この$\boldsymbol{\pi}$を推移行列$\boldsymbol{P}$を持つマルコフ連鎖の不変分布と呼びます。

また、マルコフ連鎖の性質には「既約的」「正再帰的」「非周期的」なる性質があり、すべてを満たすマルコフ連鎖はエルゴード性を見たし、エルゴード的であると言われます。

モンテカルロ法

平均$\mu$、分散$\sigma^2$の分布に独立に従う確率変数列$\lbrace x^{(t)} \rbrace$がある時、その標本平均$\bar{x}^{(t)} = \frac{1}{t} \sum_{i=1}^{t} x^{(i)}$については以下の式(大数の弱法則)が成立します。

\lim_{t \to \infty}	P(|\bar{x}^{(t)} - \mu| < \epsilon) = 1 \quad \quad (\epsilon > 0)

標本サイズをふやせば、標本平均$\bar{x}$は平均$\mu$に近づいていくということです。
これにより平均を標本平均で近似することができます。
また、任意の連続関数$h(\cdot)$によって$x$を変換した場合でも、この定理は成立します。

ベイズ推論における母数の推定値としては、事後分布の期待値を計算する際、積分を求める必要があります。
積分の対象となる関数を$h(x)$とします。
独立に同一の分布$p(x)$に従う乱数$x^{(1)},x^{(2)}, ... , x^{(T)}$を生成することができれば、積分は

\int h(x)p(x)dx \approx \frac{1}{T} \lbrace h(x^{(1)}) + h(x^{(2)}) + \cdots + h(x^{(T)}) \rbrace

と近似できます。

大数の法則から、$T \rightarrow \infty$のときの平均による近似は、対象の積分の値に確率収束します。
このようにして、乱数を使って積分を数値的に求める方法をモンテカルロ積分と呼びます。

サンプリング方法としては棄却サンプリングや重点的サンプリングがあります。

PyMC3

不偏分布が目標とする分布であるようなエルゴード的なマルコフ連鎖を構成し、推移を繰り返せば、目標分布を得ることができるというのがMCMCの概略です。

そしてPyMC3はMCMCを用いてベイズ推論ができるライブラリです。
公式チュートリアルや、Bayesian Methods for Hackersに使い方が沢山記載されています。

今回の例では、混合ガウスモデルとして、3つの正規分布から成る分布と仮定し事後分布を推測しています。
Marginalized Gaussian Mixture Modelを参考にしました。
私自身まだ分からない部分が多いので、解釈間違い等ありましたら指定いただけると嬉しいです。

まず、得られたデータは以下画像の青い部分になります。

正規分布が3つ_1.jpg

得られたデータを標準化します。

# ランダムに5000個のデータを取り出す
rng = np.random.default_rng()
obs = rng.choice(audio_data, 5000, replace=False)

# 標準化
scaler = StandardScaler()
scaler.fit(obs.reshape(-1, 1))
obs_standard = scaler.transform(obs.reshape(-1, 1)).reshape(-1)
# 標準化データ保存
pickle.dump(scaler, open('data_folder/obs_standard_data.pkl', 'wb'))

各分布の採択確率として3次元のディリクレ分布
各分布の$\mu$として正規分布。初期値として$[-1.0, 0.0, 1.5]$と設定
各分布の$\tau (\tau = 1/\sigma^2)$としてガンマ分布
混合ガウス分布にそれぞれのパラメータと得られたデータをあてます。

estimate_shape = 3
with pm.Model() as model:
    w = pm.Dirichlet('w', np.ones(estimate_shape))
    
    mu = pm.Normal('mu', mu=np.zeros(estimate_shape), sigma=1.0, shape=estimate_shape, transform=pm.transforms.ordered, initval=[-1.0, 0.0, 1.5])
    tau = pm.Gamma('tau', alpha=1.0, beta=1.0, shape=estimate_shape)
    
    x_obs = pm.NormalMixture('x_obs', w, mu, tau=tau, observed=obs_standard)

不偏分布に至るまでのチューニングとして2000
その後のサンプルとして5000
チェーン数を16としています(わざとすべて使っています)

with model:
    trace = pm.sample(5000, n_init=10000, tune=2000, return_inferencedata=True, cores=16)

トレース結果を確認します。

pm.summary(trace)

az.plot_trace(d, var_names=['w','mu','tau'], figsize=(14, 8))
plt.tight_layout()
plt.show()

r_hatの値から収束していると判断します。

SnapCrab_NoName_2022-2-13_16-50-58_No-00.png
ダウンロード.png

推測したモデルから事後分布予測サンプルを生成します。

with model:
    ppc_trace = pm.sample_posterior_predictive(trace, model=model, var_names=['x_obs'], keep_size=True, progressbar=True)
trace.add_groups(posterior_predictive=ppc_trace)


x = trace.posterior_predictive.to_array()
x = x.to_numpy()
x_all = x.reshape(-1)

# ランダムに288000個のデータを取り出す
rng = np.random.default_rng()
x_pred_obs = rng.choice(x_all, 288000, replace=True)

サンプルを描画します。

plt.figure(figsize=(8, 8))
sns.histplot(x_pred_obs)
plt.xlabel('音量(標準化)')
plt.title('取得結果分布')
plt.show()

ダウンロード (1).png

元のサンプルと比較します。概ね合致していると思います。

比較.png

参考

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?