LoginSignup
12
24

書籍『Pythonでスラスラわかる ベイズ推論「超」入門』潜在変数モデル補足

Last updated at Posted at 2024-01-07

はじめに

書籍『Pythonでスラスラわかる ベイズ推論「超」入門』著者です。
読者からいただいた質問のうち、実装コード付きで示した方がいいものがあったので、この場で解説をしたいと思います。
なお、当記事は、書籍で説明している概念については、すべてわかっている前提で記載しています。わからない部分はすべて書籍内に説明がありますので、関心を持たれた読者は是非、本編の書籍もお買い求めいただけるとありがたいです。
書籍サポートサイトのリンクは下記になります。
https://bit.ly/3uV4i3R

いただいた質問と直接の回答

いただいた質問は、5.4節 潜在変数モデルの実習コードに対するものです。
以下のコードはサポートサイト上にもアップしておきました。
https://bit.ly/3vjwY6M

潜在変数モデルでは、特別な工夫をしないと「ラベルスイッチ」と呼ばれる事象が発生します。
その対策として、書籍内の実習コードではサンプリング関数呼び出し時にchains=1のパラメータを付け、ラベルスイッチが起きない工夫をしています。

with model1:
    idata1 = pm.sample(chains=1, draws=2000, target_accept=0.99,
      random_seed=42)

読者からいただいた質問の内容としては、下記のようにchainsとdrawsのパラメータを落としてサンプリングをしたが(この呼び出し方をするとchains=2, draws=1000を指定するのと同じになります)、ラベルスイッチは特に起きていないように見える。これはどうしてかというものでした。

with model1:
    idata1_1 = pm.sample(target_accept=0.99, random_seed=42)
az.plot_trace(idata1_2, var_names=['p', 'mus', 'sigmas'], compact=False)
plt.tight_layout();

質問に直接答えると、ここで起きている事象は読者の指摘のとおりで、このパラメータの組み合わせの場合、ラベルスイッチは起きていません。
2つのchainでラベルスイッチが起きるかどうかは、ほぼ確率1/2です。このパラメータの場合、たまたまラベルスイッチが起きないパラメータの組み合わせだったということができます。

発展的な実験

質問への回答という意味では以上で終わっているのですが、この話はより発展的な話をいろいろと含んでいます。
それを試すために次のようなパラメータでサンプリングをしてみます。

with model1:
    idata1_2 = pm.sample(target_accept=0.99, chains=5, random_seed=42)

上のサンプリングと比較すると、chainsの値だけ2から5に変更し、残りのパラメータはまったく同一です。

5つのchainのグラフを同時に表示するとわかりにくいので、chainの値でフィルターをかけて、[0, 1], [2], [3, 4]の3つのグループで分けてグラフ表示してみます。

[0, 1]のグループ

chainでフィルターをかけたい場合、coords={"chain": [0, 1]}のようなパラメータを追加します。
結果は、一つ上のグラフとまったく同一になります。

# 最初の2つのchainのみ抽出
az.plot_trace(idata1_2, var_names=['p', 'mus', 'sigmas'], 
 coords={"chain": [0, 1]}, compact=False)
plt.tight_layout();

[2]のグループ

では、その次のchainはどのようなことになっているでしょうか。

3番目のchainのみ抽出
az.plot_trace(idata1_2, var_names=['p', 'mus', 'sigmas'], coords={"chain": [2]}, compact=False)
plt.tight_layout();

面白い結果になりました。右下のmus1のグラフが一番わかりやすいので、それを使って解説すると、繰り返し回数200までは、おおよそ1.5程度の値だったのが、400回以降では2.0程度の値に変わっています。
つまり、このchainでは1回のchain内でラベルスイッチが発生したということになります。

[3, 4]のグループ

最後は[3, 4]のグループです。前と同様に実装コードと結果グラフを示します。

# 最後の2つのchainのみ抽出
az.plot_trace(idata1_2, var_names=['p', 'mus', 'sigmas'], 
  coords={"chain": [3, 4]}, compact=False)
plt.tight_layout();

この2回分だけ見るとラベルスイッチはないのですが、2つ前のグラフを比較するとmus0とmus1の値の傾向がまったく逆になっていることがわかります。
これがラベルスイッチの事象ということになります。

12
24
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
24