Julia
Stan
JuliaDay 15

Stan.jl を動かしてみた

概要

  • julialang で Stan.jl のサンプルを動かします
  • cmdstan導入から
  • julia v0.6 を前提にします
  • cmdstan入りdocker image
  • ubuntu 16.04を前提にします

cmdstan導入

make: clang++: Command not found
make: *** [bin/cmdstan/stanc.o] Error 127

cmdstan v2.17.0 以降でclang++が見つからないというエラーが出たとき

/path/to/cmdstan/stan/lib/stan_math/make/detect_cc の13行目をコメントアウトしました

mkdir -p /opt/cmdstan
curl -L https://github.com/stan-dev/cmdstan/releases/download/v2.17.0/cmdstan-2.17.0.tar.gz | tar -z -x -C /opt/cmdstan --strip-components=1 -f -
sed -e "13s/^/#/" -i.bak /opt/cmdstan/stan/lib/stan_math/make/detect_cc
cd /opt/cmdstan && make build
export CMDSTAN_HOME="/opt/cmdstan"

cmdstan のサンプルを動かしてみる

cd /opt/cmdstan/
make examples/bernoulli/bernoulli
examples/bernoulli/bernoulli sample data file=examples/bernoulli/bernoulli.data.R
bin/stansummary output.csv

Stan.jl 導入

Stan.jl の walkthrough を追いかけます

Pkg.add("Stan.jl")
Pkg.add("Mamba.jl")

using Stan, Mamba

# bernouli の stanmodel を定義します
const bernoullistanmodel = "
data { 
  int<lower=0> N; 
  int<lower=0,upper=1> y[N];
} 
parameters {
  real<lower=0,upper=1> theta;
} 
model {
  theta ~ beta(1,1);
    y ~ bernoulli(theta);
}
"

# Stanmodel() で、stanmodel オブジェクトを作ります
stanmodel = Stanmodel(name="bernoulli", model=bernoullistanmodel);
stanmodel |> display

# input data を定義します
const bernoullidata = Dict("N" => 10, "y" => [0, 1, 0, 1, 0, 0, 0, 0, 0, 1])

# stan() 実行! ProDir オプションは外しています
rc, sim1 = stan(stanmodel, [bernoullidata],  CmdStanDir=CMDSTAN_HOME)

# rc ( return code ) が成功 (rc==0) であれば、Mamba.jl の describe() 関数を実行
# walkthrough のコードのままでは getindex() の error がおこる
if rc == 0
  rc, sim1 = stan(stanmodel, [bernoullidata], CmdStanDir=CMDSTAN_HOME)
  println("Subset Sampler Output")
  sim = sim1[1:1000, ["lp__", "theta", "accept_stat__"], :]
  describe(sim)
end

# sim1 の中身を dump してみます
dump(sim1)

# Mamba.jl から Gadfly.jl によるグラフ表示を描画します
println("Brooks, Gelman and Rubin Convergence Diagnostic")
try
  gelmandiag(sim, mpsrf=true, transform=true) |> display
catch e
  #println(e)
  gelmandiag(sim, mpsrf=false, transform=true) |> display
end

# log{T <: Number}(x::AbstractArray{T}) に関する warnning がでます v0.6
# ここを変更かな
# https://github.com/brian-j-smith/Mamba.jl/blob/ec76a8410a92893a53069607c591aef0868c03e6/src/output/chains.jl#L242)

println()

println("Geweke Convergence Diagnostic")
gewekediag(sim) |> display
println()

println("Highest Posterior Density Intervals")
hpd(sim) |> display
println()

println("Cross-Correlations")
cor(sim) |> display
println()

println("Lag-Autocorrelations")
autocor(sim) |> display
println()

# シミュレーション結果を plot します
p = plot(sim, [:trace, :mean, :density, :autocor], legend=true);
draw(p, ncol=4, filename="summaryplot", fmt=:svg) # svg で出力
draw(p, ncol=4, filename="summaryplot", fmt=:pdf) # pdf で出力
draw(p, ncol=4) # X に描画