LoginSignup
4
6

More than 1 year has passed since last update.

ベイズ統計の勉強が捗る pymc3のケーススタディ(Bayesian mediation analysis) 解説

Last updated at Posted at 2022-02-06

はじめに

本内容はpymc3 Bayesian mediation analysisの解説記事です。

概要

ベイズ統計をAIコンペや実業務で使えたらいいなーと思い、独学で勉強しています。参考書から大枠のイメージを理解することはできましたが、じゃあ実際どう使えばいいんだい、というところが全く分かりませんでした。

何かハンズオン形式のチュートリアルみたいな教材がないかなと調べていたところ、pymc3のチュートリアルがとても良さそうだったので、こちらもとに学習を進めることにしました。

ただ、pymc3のチュートリアルは全て英語で記載されており、かつ統計的な専門知識もある程度必要になるため、理解は少し時間がかかりました。日本語の詳細な解説記事があればいいのになーと思い、頭の整理も兼ねて本記事を執筆しました。

対象とする課題

pymc3 Bayesian mediation analysis

実行環境

Google Colab

ケーススタディ解説

mediation analysisとは

媒介分析は、影響を与える変数(独立変数)と、影響を受ける変数(従属変数)との間を、他の変数(媒介変数)が介在しているようなモデルを検討する分析であるとされています1。つまり、X⇨Yの関係の間にmという別の変数が"媒介変数"として入り、X⇨m⇨Yという関係になるような状況を分析することを言います。
説明は参考文献2が分かりやすかったので引用させていただくと、具体例として「風が吹けば桶屋が儲かる」を挙げられていました。これはまさに風が吹く⇨桶屋が儲かるの間に多くの媒介変数が含まれているという状況かと思います。

問題設定

以下のような、X⇨Y(直接効果)およびX⇨M⇨Y(間接効果)の関係性を考えます。最終的な目的は、各変数(X, M)がそれぞれYにどの程度の影響を与えているのかを解析することです。

image.png

単純化して書くと、

M = a \times X \\
Y = c' \times X + b \times M

となります。ベイズを用いてa, b, c'を逆解析できれば、X, MがYに与える影響を解析できた、ということになります。
ただ、実際はこんなにも綺麗な線形の関数の上にデータが乗るはずはなく、ばらつき・ノイズが入ったデータになります。従って、以下のようにモデリングして考えていきます。

M \sim Normal(i_M + a \times X, \sigma_M) \\
Y \sim Normal(i_Y + c' \times X + b \times M, \sigma_Y)

ここでi, σはそれぞれ切片、標準偏差になります。Normalは正規分布を表しています。

データの作成

まずはモデリングのためのデータを準備します。今回は実データではなく、自前で作っています。ここで作成したデータを、後程ベイズを使って解析していきます。

def make_data():
    N = 75
    a, b, cprime = 0.5, 0.6, 0.3
    im, iy, σm, σy = 2.0, 0.0, 0.5, 0.5
    x = rng.normal(loc=0, scale=1, size=N)
    m = im + rng.normal(loc=a * x, scale=σm, size=N)
    y = iy + (cprime * x) + rng.normal(loc=b * m, scale=σy, size=N)
    print(f"True direct effect = {cprime}")
    print(f"True indirect effect = {a*b}")
    print(f"True total effect = {cprime+a*b}")
    return x, m, y


x, m, y = make_data()

sns.pairplot(DataFrame({"x": x, "m": m, "y": y}));

image.png

式で書くと以下のような正規分布になります。

M \sim Normal(2.0 + 0.5 \times X, 0.5) \\
Y \sim Normal(0.0 + 0.3 \times X + 0.6 \times M, 0.5)

平均0, 標準偏差1の正規分布からランダムに75個のXを生成し、上述の式に代入することで、X, M, Yのデータを作成しています。Xを生成しているコードは以下の部分にあたります。

x = rng.normal(loc=0, scale=1, size=N)

モデリング

ここからモデルを定義していきます。以下のコードで定義されています。

def mediation_model(x, m, y):
    with pm.Model() as model:
        # 変数の定義
        x = pm.Data("x", x)
        y = pm.Data("y", y)
        m = pm.Data("m", m)

        # 切片の事前分布 (正規分布)
        im = pm.Normal("im", mu=0, sigma=10)
        iy = pm.Normal("iy", mu=0, sigma=10)
        # 傾きの事前分布 (正規分布)
        a = pm.Normal("a", mu=0, sigma=10)
        b = pm.Normal("b", mu=0, sigma=10)
        cprime = pm.Normal("cprime", mu=0, sigma=10)
        # 標準偏差の事前分布 (半コーシー分布)
        σm = pm.HalfCauchy("σm", 1)
        σy = pm.HalfCauchy("σy", 1)

        # 尤度関数
        pm.Normal("m_likehood", mu=im + a * x, sigma=σm, observed=m)
        pm.Normal("y_likehood", mu=iy + b * m + cprime * x, sigma=σy, observed=y)

        # 間接的な効果(x→m→y)と全体の効果(x→m→y と x→y)を算出する
        indirect_effect = pm.Deterministic("indirect effect", a * b)
        total_effect = pm.Deterministic("total effect", a * b + cprime)

    return model

# モデルにデータ(x, m, y)を代入
model = mediation_model(x, m, y)

# グラフィカルモデルによる可視化
pm.model_to_graphviz(model)

image.png

観測される値はx, m ,y の3種類ですので、これらの観測値から分布に関するパラメータ(切片・傾き・標準偏差)の事後分布を算出する、という流れになります。
今回のモデリングでは、切片・傾きの事前分布として正規分布が選択されています。切片や傾きに関する情報はとくにないので、まぁそんなに大きな値は取らないだろうけどよくわからないからとりあえず標準偏差10(95%信頼区間: 0 ± 19.6)という少し広めの幅でとっているのだろうと思います。
また、標準偏差の事前分布は半コーシー分布が選択されています。これは標準偏差が非負であることに関連しています。半コーシー分布も非負でかつ0からピークアウトする形状ということもあり、標準偏差の事前分布によく利用されているようです34

事後分布のサンプリング

傾き・切片・標準偏差の事後分布のサンプリングをします。pymc3のデフォルトでは、MCMCの一種であるNUTSという手法でサンプリングを行っています。
(サンプリングの考え方については少し複雑ですが、こちらのスライドが分かりやすいです。)

with model:
    result = pm.sample(
        2000,
        tune=4000,
        chains=2,
        target_accept=0.9,
        random_seed=42,
        return_inferencedata=True,
        idata_kwargs={"dims": {"x": ["obs_id"], "m": ["obs_id"], "y": ["obs_id"]}},
    )

az.plot_trace(result);

image.png

右図は各パラメータの値を2000回サンプリングした結果を表示しています。これらのサンプリング結果を確率分布の図として表したのが左図(ヒストグラム化のようなイメージ)となり、これが事後分布になります。
例えばパラメータbの左図に着目すると、平均おおよそ0.7の事後分布になっていることがわかります。

この図から、計算が問題なく収束しているかを(定性的にですが)確認できます。
サンプリング時にchains=2というパラメータを設定しています。これは、各変数の事後分布に対するサンプリングを何回行うかという設定です。今回は2回と設定しているため、上図のグラフに波線と実線の2つの線が表れています。これら2つの確率分布を比較して、似た形状になっていることが、計算がきちんと収束していると判定できる一つの基準になります。また、例えば、ピークが複数あるような多峰分布となる事後分布に対しては局所解に陥る可能性があるため、 その確認にも有効です。(とはいえ2回だけだと同じ局所解に陥るという可能性もあるとは思いますが。。)
chain数についてはGelman-Rubinの収束診断などでも言及されています。

パラメータの可視化

次に同時事後分布を可視化します。これにより、各パラメータのおおよその関係性が見えるようになります。

az.plot_pair(
    result,
    marginals=True,
    point_estimate="median",
    figsize=(12, 12),
    scatter_kwargs={"alpha": 0.05},
    var_names=["a", "b", "cprime", "indirect effect", "total effect"],
);

image.png

例えば(b)-(indirect effect)の関係に着目すると、bが大きくなるほどindirect effectが大きい値をとりやすくなる(右上がりの分布)になっていることがわかります。これは、モデリングの際にindirect effectをa×bで定義していたことからも分かる通りです。
また、(b)-(cprime)の関係に着目すると、負の相関関係があることがわかります。目的変数yに対して、bは間接効果、cprimeは直接効果を表すパラメータでした。この結果は、bが大きくなればなるほど間接効果が大きくなるため、結果的に直接効果は小さくなる、といった関係性を表していると考えられます。

間接効果・直接効果の確認

ax = az.plot_posterior(
    result,
    var_names=["cprime", "indirect effect", "total effect"],
    ref_val=0,
    hdi_prob=0.95,
    figsize=(14, 4),
)
ax[0].set(title="direct effect");

image.png

上手のとおり、直接効果、間接効果、全体の効果の事後平均はそれぞれ0.28・0.49・0.77と算出されました。
直接効果は、xが1単位増加するごとに、直接効果x→yによりyが平均0.28増加することを意味します。
同様に、間接効果は、xが1単位増加するごとに、x→m→yの経路を経てyが平均0.49増加することを意味します。また、間接効果が0である確率は非常に小さいため、「間接効果がないとはいえない」、という結果になっていることがわかります。
また、全体の効果は0.77であり、これはxが1単位増加するごとに、直接経路と間接経路の両方を通じてyが0.77増加することを意味します。

全体効果の確認

最後に、別モデルを用いて、前節で示した全体の効果が妥当かどうかをチェックします。ここでは、x⇨yの直接効果のみしか存在しないとしてモデリングをし、その係数の事後分布が前節の全体の効果と一致するかどうかを確認します。
具体的には以下のような単純なモデルを作成し、係数cの事後分布を全体の効果と比較します。

Yi \sim Normal(i_Y + c \times Xi, \sigma_Y)

# モデリング
with pm.Model() as total_effect_model:
    _x = pm.Data("_x", x)
    iy = pm.Normal("iy", mu=0, sigma=1)
    c = pm.Normal("c", mu=0, sigma=1)
    σy = pm.HalfCauchy("σy", 1)
    μy = iy + c * _x
    _y = pm.Normal("_y", mu=μy, sd=σy, observed=y)

# サンプリング
with total_effect_model:
    total_effect_result = pm.sample(
        2000,
        tune=4000,
        chains=2,
        target_accept=0.9,
        random_seed=42,
        return_inferencedata=True,
        idata_kwargs={"dims": {"x": ["obs_id"], "y": ["obs_id"]}},
    )

# 可視化
fig, ax = plt.subplots(figsize=(14, 4))
az.plot_posterior(
    total_effect_result, var_names=["c"], point_estimate=None, hdi_prob="hide", c="r", lw=4, ax=ax
)
az.plot_posterior(
    result, var_names=["total effect"], point_estimate=None, hdi_prob="hide", c="k", lw=4, ax=ax
);

image.png

赤線は係数cの事後分布、黒線は前節でサンプリングした全体の効果の事後分布です。これらの分布がほぼ同じになっていることから、媒介分析を用いて全体の効果を算出できたと考えられます。

最後に

今回はpymc3の媒介分析ケーススタディを解説してみました。

もし内容に誤りがありましたら、コメントにてご指摘いただけますと大変有難いです。

参考文献


  1. 伊藤 萌恵, "HADによる媒介分析" VOL.34 (2019).  

  2. https://note.com/i_partners/n/nc08bc6e3d5de 

  3. 小森政嗣, "これからベイズ統計を使ってみたい人に" 日本音響学会誌 75, 6 (2019). https://www.jstage.jst.go.jp/article/jasj/75/6/75_351/_pdf 

  4. https://www.slideshare.net/hoxo_m/ss-59418886 

4
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
6