
More than 5 years have passed since last update.

[MCMC] Computing WAIC with pystan


Introduction

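I recently started studying Bayesian statistics and decided to use Stan as my MCMC library. However, almost every textbook pairs it with R and there was little information for Python, so I am writing a post for the first time in a while.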

Environment

  • Ubuntu 16.10 64bit
  • Anaconda 4.4.0
  • python 2.7.13
  • pystan 2.14.0.0

Installing pystan

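According to the official documentation,
conda install pystan
is all you need. In my case, though, compiling a model with only that installed failed with
undefined symbol: _ZTVNSt7__cxx1118basic_stringstreamIcSt11char_traitsIcESaIcEEE
so I additionally installed gcc:
conda install gcc
The missing symbol belongs to the newer C++11 ABI of libstdc++ (std::__cxx11), so the error presumably comes from a mismatch between the system compiler and the libstdc++ bundled with Anaconda.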

Creating a test model and data

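The example is borrowed from Hideki Toyoda, 『実践ベイズモデリング』 (Practical Bayesian Modeling). It is a simple linear regression, y_i ~ Normal(b + a*x_i, sigma), and the generated quantities block stores each observation's log-likelihood in log_lik, which is what WAIC needs.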

import pystan
import numpy as np

model_code = '''
data{
    int<lower=0> N;
    vector[N] x;
    vector[N] y;
}
parameters{
    real b;
    real a;
    real<lower=0> sigma;
}
transformed parameters{
    real mu[N];
    for(i in 1:N){
        mu[i] = b + a*x[i];
    }
}
model{
    for(i in 1:N){
        y[i] ~ normal(mu[i], sigma);
    }
}
generated quantities{
    vector[N] log_lik;
    for(i in 1:N){
        log_lik[i] = normal_lpdf(y[i] | mu[i], sigma);
    }
}
'''

sm = pystan.StanModel(model_code=model_code)

x = [150.5,160.2,148,177.1,162,148.2,163.8,166.9,164.9,154.3,176.1,162.7,150.5,131.4,171.5,157.5,157.8,169.3,167.9,165.1,
169,167.4,158.9,134.1,165.4,157.3,156.1,140.4,152.3,163,174.3,156.8,162.7,157.4,141.5,153,153.3,157.3,171.2,167.2,
156,155,166.4,164.7,149.7,149.5,162.4,167.2,156.7,168.6,162.8,150.7,162.1,144.4,175.2,181.8,153.6,145.5,164.8,156.4,
186.8,157.5,166.3,158.3,149.1,160.3,136.3,175.6,159.8,184.1,163.7,149.5,165.3,146.8,143,161.5,152.7,158,158.9,150.9,
151.2,156.4,172.1,139.7,165.1,162,170.8,154.3,162.4,161.2,151.5,172.5,171.9,166.4,177,164.7,142.7,151.1,143.3,152.3]
y = [158.7,163.3,156.6,164.1,158.4,175.4,168,169.4,165.7,174.8,158.5,159.8,173,158.4,161.5,160.3,160.8,161,166.5,161.8,
159.5,172.4,161.5,161.7,162.3,168,162.5,162.7,158.2,160.7,163.4,158.9,166.7,152.4,165.1,152.2,160.9,159.3,158.4,162.6,
149.6,171.2,151.3,159.8,155.2,157.7,177.6,163.1,154,151.5,166.2,162.9,160.8,156.5,152.6,155.5,170,158.7,153.3,176.1,
166,161.3,170.4,169.2,158.7,178.4,161.2,153,162,164.5,179.2,163.7,166.2,162.5,160.7,162.8,168.5,177.5,170.2,171.5,
154.4,169.9,164.5,152.7,166.6,161.9,173.3,157.6,160,156.5,161.8,165.8,157.9,168.8,154.5,155.7,173.1,155.9,165.9,160.3]
data = {'N':100, 'x':x, 'y':y}

fit = sm.sampling(data=data, iter=2000, chains=4)  # 4 chains; by default the first half of each chain is warmup
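The generated quantities block writes normal_lpdf(y[i] | mu[i], sigma) for every draw, so fit.extract()['log_lik'] should be a (T, N) array with one row per post-warmup draw. Assuming PyStan 2's default warmup of iter // 2, the call above keeps 4 × (2000 - 1000) = 4000 draws, and N = 100. As a cross-check, or for a model without such a block, the same matrix could be rebuilt in NumPy from the draws of b, a and sigma. The sketch below does this for this linear model only; normal_log_lik is a hypothetical helper, not part of pystan.

```python
import math

import numpy as np


def normal_log_lik(b, a, sigma, x, y):
    """Pointwise log-likelihood for y_i ~ Normal(b + a * x_i, sigma).

    Hypothetical helper (not a pystan API).
    b, a, sigma: length-T sequences of posterior draws, aligned draw by draw.
    x, y: length-N data sequences.
    Returns a (T, N) array laid out like fit.extract()['log_lik'].
    """
    # Draws become column vectors (T, 1) and data become row vectors (1, N),
    # so broadcasting gives one row per draw and one column per observation.
    b, a, sigma = (np.asarray(v, dtype=float)[:, None] for v in (b, a, sigma))
    x = np.asarray(x, dtype=float)[None, :]
    y = np.asarray(y, dtype=float)[None, :]
    mu = b + a * x  # matches mu[i] = b + a*x[i] in the Stan model
    # Normal log density, the same quantity as normal_lpdf(y[i] | mu[i], sigma)
    return -0.5 * np.log(2 * math.pi * sigma ** 2) - (y - mu) ** 2 / (2 * sigma ** 2)


# Tiny check: the standard normal log density at 0 is -0.5 * log(2 * pi).
print(normal_log_lik([0.0], [0.0], [1.0], [0.0], [0.0]))  # about -0.9189

# With a real PyStan 2 fit (not run here), something like
#   s = fit.extract()
#   ll = normal_log_lik(s['b'], s['a'], s['sigma'], x, y)
# should agree with s['log_lik'] up to floating-point error.
```

Building the matrix with broadcasting rather than a double loop keeps it fast for thousands of draws.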

Computing WAIC

WAIC stands for Widely Applicable Information Criterion (Watanabe, 2010) and is one of the information criteria used for Bayesian models. It is computed as

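WAIC = -2\,\mathrm{lppd}_{\mathrm{waic}} + 2\,p_{\mathrm{waic}} \\
\mathrm{lppd}_{\mathrm{waic}} \approx \sum_{i=1}^{N} \log \Bigl\{ \frac{1}{T} \sum_{t=1}^{T} f(x_i \mid \theta^{(t)}) \Bigr\} \\
p_{\mathrm{waic}} = \sum_{i=1}^{N} V_{t=1}^{T} \bigl[ \log f(x_i \mid \theta^{(t)}) \bigr]

where N is the number of observations, T is the number of posterior draws, θ^(t) is the t-th draw, and V_{t=1}^{T}[·] is the variance over those T draws. In the Stan model above, log f(x_i | θ^(t)) is log_lik[i] evaluated at draw t.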
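To make the formula concrete, the sketch below follows it term by term with plain loops: for each observation i it averages the likelihood over the T draws for lppd, and takes the variance of the log-likelihood over the draws for p_waic. It assumes the input is a (T, N) array laid out like log_lik, and it divides the variance by T, matching np.var's default (ddof=0) used in the waic function that follows. The name waic_from_formula is hypothetical and meant for illustration only.

```python
import math

import numpy as np


def waic_from_formula(log_lik):
    """WAIC computed term by term, mirroring the formula (illustrative only).

    log_lik: (T, N) array; entry [t, i] is log f(x_i | theta^(t)).
    """
    log_lik = np.asarray(log_lik, dtype=float)
    T, N = log_lik.shape
    lppd = 0.0
    p_waic = 0.0
    for i in range(N):
        column = log_lik[:, i]  # log f(x_i | theta^(t)) for t = 1..T
        # log of (1/T) * sum_t f(x_i | theta^(t))
        lppd += math.log(sum(math.exp(v) for v in column) / T)
        # V_t[log f(x_i | theta^(t))], dividing by T like np.var's default
        mean = sum(column) / T
        p_waic += sum((v - mean) ** 2 for v in column) / T
    return -2.0 * lppd + 2.0 * p_waic


# Toy input: T = 2 draws, N = 2 observations, given as likelihood values.
toy = np.log([[0.2, 0.5],
              [0.4, 0.5]])
print(waic_from_formula(toy))  # about 4.0345
```

By hand: lppd = log(0.3) + log(0.5) = log(0.15), and p_waic = (log 2 / 2)^2 + 0, so WAIC = -2 log(0.15) + 2 (log 2 / 2)^2 ≈ 4.0345.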

So we define the following function:

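def waic(fit):
    # (T, N) array: one row per posterior draw, one column per observation
    log_lik = fit.extract()['log_lik']
    # lppd: log of the posterior-mean likelihood, summed over observations
    lppd = np.log(np.exp(log_lik).mean(axis=0)).sum()
    # p_waic: variance of the log-likelihood over draws, summed over observations
    p_waic = np.var(log_lik, axis=0).sum()
    waic = -2*lppd + 2*p_waic
    return round(waic, 3)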
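The function above exponentiates log_lik directly. For this model the pointwise log-likelihoods are moderate, so that is likely fine, but with very negative values (for example an observation far out in a tail) np.exp can underflow to 0.0 and lppd becomes -inf. A common remedy, not used in the original article, is the log-sum-exp trick: subtract each column's maximum before exponentiating. The sketch below, waic_stable (a hypothetical name), applies it with NumPy only and takes the log_lik array rather than the fit; scipy.special.logsumexp(log_lik, axis=0) - np.log(T) would give the same per-observation terms.

```python
import numpy as np


def waic_stable(log_lik):
    """WAIC using the log-sum-exp trick for lppd (sketch, unrounded result).

    log_lik: (T, N) array of pointwise log-likelihoods, one row per draw,
    e.g. fit.extract()['log_lik'] from PyStan 2.
    """
    log_lik = np.asarray(log_lik, dtype=float)
    T = log_lik.shape[0]
    # log((1/T) * sum_t exp(l_t)) = m + log(sum_t exp(l_t - m)) - log(T),
    # with m the column maximum: the largest term becomes exp(0) = 1,
    # so the sum cannot underflow to 0.
    m = log_lik.max(axis=0)
    lppd = (m + np.log(np.exp(log_lik - m).sum(axis=0)) - np.log(T)).sum()
    p_waic = log_lik.var(axis=0).sum()  # ddof=0, as in the waic() function
    return -2.0 * lppd + 2.0 * p_waic


# np.exp(-800.0) underflows to 0.0, so the naive formula would return inf here.
extreme = np.array([[-800.0, -1.0],
                    [-801.0, -2.0]])  # T = 2 draws, N = 2 observations
print(waic_stable(extreme))  # finite
```

Passing fit.extract()['log_lik'] to waic_stable should reproduce the value from waic(fit) up to rounding, since only the order of the arithmetic changes.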

Using it:

>>> waic(fit)
669.028

WAIC computed successfully!
