LoginSignup
0
2

More than 1 year has passed since last update.

NumPyroで階層ベイズをやってみる

Last updated at Posted at 2023-04-21

ここでは、NumPyroで階層ベイズを行う方法について紹介する。NumPyroはJAXをバックエンドにする確率的プログラミング言語であり、推論が爆速であることで知られる。

次のようなソーシャルメディアのデータ(ぶっちゃけtwitter)について分析することを想定したい。
ユーザーの行動データが複数あり、ユーザーが他のユーザーをフォローするなどコミュニティが形成される場合、コミュニティレベル、ユーザーレベル、ユーザーの行動レベルと3つの階層が作られることが考えられる。独立変数はユーザーが閲覧したテキストの表現であり、従属変数はユーザーが書いたテキストの表現であるとする(例えば、辞書や教師あり学習でポジネガ分類をした)。

独立変数: x1, x2、 x1とx2の交互作用
従属変数: y(ある行動をするかどうか0, 1)
階層構造: ユーザーレベル、コミュニティレベル

その場合、コードは下記のようになる。

def model(N, User_N,  User, Community_N ,Community2User, x1, x2, y = None):
    a_0 = numpyro.sample("a_0", dist.Normal(0, 10))
    sigma_a = numpyro.sample("sigma", dist.HalfNormal(5))
    b1_0 = numpyro.sample("b1_0", dist.Normal(0, 10))
    sigma_b1 = numpyro.sample("sigma_b1", dist.HalfNormal(5))
    b2_0 = numpyro.sample("b2_0", dist.Normal(0, 10))
    sigma_b2 = numpyro.sample("sigma_b2", dist.HalfNormal(5))
    b_interaction_0 = numpyro.sample("b_interaction_0", dist.Normal(0, 10))
    sigma_b_interaction = numpyro.sample("sigma_b_interaction", dist.HalfNormal(5))
    sigma_a_community = numpyro.sample("sigma_a_community", dist.HalfNormal(5))
    sigma_b1_community = numpyro.sample("sigma_b1_community", dist.HalfNormal(5))
    sigma_b2_community = numpyro.sample("sigma_b2_community", dist.HalfNormal(5))
    sigma_b_interaction_community = numpyro.sample("sigma_b_interaction_community", dist.HalfNormal(5))

    with numpyro.plate("Communitiy_N", size=Community_N):
        a_community = numpyro.sample("a_community", dist.Normal(a_0, sigma_a))
        b1_community  = numpyro.sample("b1_community", dist.Normal(b1_0, sigma_b1))
        b2_community  = numpyro.sample("b2_community", dist.Normal(b2_0, sigma_b2))
        b_interaction_community  = numpyro.sample("b_interaction_community", dist.Normal(b_interaction_0, sigma_b_interaction))
    
    with numpyro.plate("User_N", size=User_N):
        a = numpyro.sample("a", dist.Normal(a_community[Community2User], sigma_a_community))
        b1  = numpyro.sample("b1", dist.Normal(b1_community[Community2User], sigma_b1_community))
        b2  = numpyro.sample("b2", dist.Normal(b2_community[Community2User], sigma_b2_community))
        b_interaction  = numpyro.sample("b_interaction", dist.Normal(b_interaction_community[Community2User], sigma_b_interaction_community))
    q = sigmoid(a[User] + b1[User]*x1 +  b2[User]*x2 +  b_interaction[User]*x1*x2)   

    with numpyro.plate("N", size = N):
        y_pred = numpyro.sample("y_pred", dist.Bernoulli(q), obs=y)

事前分布の選び方はGelmanの記事を参考にすればいい。下記の記事では5つのレベルの事前分布について書いている。

  • 一様分布(推奨されない)
  • 非常に曖昧な事前分布: normal(0, 1e6)(推奨されない)
  • 弱情報事前分布: normal(0, 10)
  • 一般的な弱事前情報事前分布: normal(0.4, 0.2)
  • 特定の情報事前分布

グラフィカルモデルの描画

NumPyroの利点として、グラフィカルモデルを描画できることが挙げられる。

numpyro.render_model(model, model_args=(N, User_N,  User, Community_N ,Community2User, x1, x2, y = None), render_params=True, filename="model.png")

image.png

MCMC

事後分布を求めるため、MCMCを行う。NumPyroはPyroのMCMC100倍早いらしい。

kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(kernel, num_samples=2000, num_warmup=300, num_chains=4, chain_method="parallel")
mcmc.run(jax.random.PRNGKey(0), N, User_N,  User, Community_N ,Community2User, x1, x2, y)

変分推論

データ量が多い、モデルが複雑であるという場合は、変分推論を使う。

guide = infer.autoguide.AutoDiagonalNormal(model)
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, N, User_N,  User, Community_N ,Community2User, x1, x2, y)
params = svi_result.params

num_samples = 1000
dist_posterior = Predictive(model=guide, params=params, num_samples=num_samples)
samples_posterior = dist_posterior(random.PRNGKey(1))
samples_posterior
0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2