LoginSignup
11
6

More than 3 years have passed since last update.

【Pyro】確率的プログラミング言語Pyroによる統計モデリング② ~単回帰モデル~

Last updated at Posted at 2020-04-20

はじめに

前回の記事では,確率的プログラミング言語Pyroの概要と,Pyroを用いた標本サンプリングの仕方を紹介しました.
今回は,Pyroを用いて「実践Data Scienceシリーズ RとStanではじめるベイズ統計モデリングによるデータ分析入門」(以下,参考書と呼びます.)にある最初の例題(第3章第2部 単回帰モデル)をもとに,Pyroでの実装を行います.
なお,実装は 3-2 単回帰モデル(Google Colaboratory) にて公開しています.

利用データ & 例題紹介

今回扱うデータ(書籍3-2参照)は ビールの売り上げと気温の関係 です.統計モデリングにより,売り上げを気温によって説明するモデルを構築します.
実データのサンプルおよび散布図は以下の通りです.

# データ読み込み
beer_sales_2 = pd.read_csv(
    'https://raw.githubusercontent.com/logics-of-blue/book-r-stan-bayesian-model-intro/master/book-data/3-2-1-beer-sales-2.csv'
) # shape = (100, 2)
beer_sales_2.head()
sales temperature
0 41.68 13.7
1 110.99 24
2 65.32 21.5
3 72.64 13.4
4 76.54 28.9
# 可視化
beer_sales_2.plot.scatter('temperature', 'sales')
plt.title("ビールの売り上げと気温の関係")

image.png
なんとなく正の相関がありそうですね.
単回帰モデルによって, sale(売り上げ)をtemperature(気温)で説明するモデルを記述すると,以下のようになります.
$sales_i \sim \rm{Normal}(Intercept + \beta * temperature_i, \sigma^2)$
ベイズ統計モデリングでは,上式の中で推定したいパラメタである$Intercept$, $\beta$, $\sigma^2$ を確率変数として扱い(これらを潜在変数と呼ぶ),観測データを用いて事後分布を推定します.

Pyroによる実装

Pyroでは,以下の手順でモデル記述&パラメタ推定を行います.
1. modelメソッド記述
2. 事後分布の推定
以下では,構成要素を一つずつ説明していきます.

modelメソッド記述

Pyroでは,仮定した統計モデルがデータを生成するプロセスを一つのメソッド内に記述します.通例,メソッドの名前はmodelと名付けます.その際,以下の点を満たすように気をつきます.
* 前回の記事で紹介したpyro.samplepyro.plateを用いてデータ生成プロセスを記述する.
* 引数として観測データを受け,データ生成プロセスの中で観測データがどこの部分にあたるかを記述する.
以下はその実装です.

import torch
import pyro
import pyro.distributions as dist


def model(temperature: torch.Tensor, sales: torch.Tensor=None):
    intercept = pyro.sample("intercept", dist.Normal(0, 1000))
    beta = pyro.sample("beta", dist.Normal(0, 1000))
    sigma2 = pyro.sample("sigma2", dist.Uniform(0, 1000))
    with pyro.plate("plate", size=temperature.size(0)):
        y_ = intercept + beta * temperature
        y = pyro.sample("y", dist.Normal(y_, sigma2), obs=sales)
    return y

model内でpyro.sampleによって定義された4つの変数( $Intercept$, $\beta$, $\sigma^2$, $y$ )は全て確率変数として扱われます.そして,確率変数のうち観測データと対応していないもの( $y$ 以外)が潜在変数であり,後に事後分布を推定する対象となります.

事後分布の推定

統計モデルをmodelメソッドを記述したら,潜在変数の事後分布を推定します.
推定方法には,主に以下の2つが提供されています.
1. MCMC
2. 変分推論

MCMC

推定に必要なメソッドはpyro.inferに提供されています.
NUTSカーネルを用いた推定を行うには以下のように記述します.

import pyro.infer as infer

# 説明変数・観測変数をtorch.Tensorに変換する
temperature = torch.Tensor(beer_sales_2.temperature)
sales = torch.Tensor(beer_sales_2.sales)

# 事後分布の推定
nuts_kernel = infer.NUTS(model, adapt_step_size=True, jit_compile=True, ignore_jit_warnings=True)
mcmc = infer.MCMC(nuts_kernel, num_samples=3000, warmup_steps=200, num_chains=3)
mcmc.run(temperature, sales)
mean std median 5.0% 95.0% n_eff r_hat
intercept 21.00 6.03 20.99 11.01 30.80 2869.59 1.00
beta 2.47 0.29 2.47 1.98 2.94 2866.62 1.00
sigma2 17.06 1.22 16.98 15.02 19.00 4014.61 1.00

modelメソッド内の確率変数のうち,推定対象である切片(intercept),傾き(beta),分散(sigma2)の事後分布が推定されました.切片はおよそ21,傾きはおよそ2.5となり,参考書籍で示された結果と大方一致しています.

変分推論による推定

Pyroでは変分推論による推定も行えます.変分推論では,求めたい潜在変数$Z$の事後分布 $p(Z|X)$ を別の関数 $q(Z)$ で近似することを考えます.Pyroではこの $q(Z)$ に当たるメソッド(guide)を記述すれば,$p(Z|X)$ を最も良く近似する $q(Z)$ を求めることができます.
例題に置き換えながら状況を整理しましょう.modelメソッドで記述されたデータ生成プロセスを振り返ると,この問題で推定したい潜在変数はintercept, beta, sigma2の3つであることがわかります.これら3つの潜在変数が生成される過程をguideメソッドに記述すれば良いわけです.
ただし,guideを適当に書いては $q(Z)$ を $p(Z|X)$ に近づけていくことができません.Pyroではpyro.paramによって宣言されたパラメタを,$q(Z)$ と $p(Z|X)$ の分布をより近づける方向に更新することにより,変分推論を実現します.その際,Pytorchの自動微分と勾配降下を使います.
以上を踏まえ,guideメソッドの書き方を以下にまとめます.
* pyro.paramによって更新対象パラメタと初期値を宣言しておく.
* guideメソッド内でそのパラメタを読み込む.制約条件はdist.constraintsで指定する.
* dist.確率分布(更新対象パラメタ)およびpyro.sampleで潜在変数を生成する.
* 引数はmodelメソッドと同じでなければならない.

以下は,この問題に対するguideの実装の一例です.今回は3つの潜在変数を独立と仮定した上で,関数形としてはデルタ分布を選びました.

# 変数の宣言と初期値の設定
pyro.param("intercept_q", torch.tensor(0.))
pyro.param("beta_q", torch.tensor(0.))
pyro.param("sigma2_q", torch.tensor(10.))

# q(Z)の実装
def guide(temperature, sales=None):
    intercept_q = pyro.param("intercept_q")
    beta_q = pyro.param("beta_q")
    sigma2_q = pyro.param("sigma2_q", constraint=dist.constraints.positive) 
    intercept = pyro.sample("intercept", dist.Delta(intercept_q))
    beta = pyro.sample("beta", dist.Delta(beta_q))
    sigma2 = pyro.sample("sigma2", dist.Delta(sigma2_q))

例えばデルタ分布ではなく正規分布にするなら,各潜在変数に対してpyro.paramで平均($\mu$)と分散($\sigma^2$)を定義した上で,dist.Normal($\mu$,$\sigma^2$)からpyro.sampleを使って生成するように書き換えれば良いだけです.このように,guide部分を変更することにより柔軟に事後分布の推定方法を変更できます.
ちなみに,コード例のように全ての潜在変数を独立なデルタ分布とするなら,
guide = infer.autoguide.guides.AutoDelta(model)
と書いても同じです.infer.autogide.guidesには自動でguideメソッドを作ってくれる便利な機能が提供されています.(http://docs.pyro.ai/en/0.2.1-release/contrib.autoguide.html)

model, guideメソッドを書いたら事後分布を簡単に求められます.

# 推定対象のパラメタの値を1エポックごとに保持しておくためのメソッド
def update_param_dict(param_dict):
    for name in pyro.get_param_store():
        param_dict[name].append(pyro.param(name).item())

# 変分推論の実行
adam_params = {"lr": 1e-1, "betas": (0.95, 0.999)}
oprimizer = pyro.optim.Adam(adam_params)

svi = infer.SVI(model, guide, oprimizer, loss=infer.JitTrace_ELBO(),)

n_steps = 10000
losses = []
param_dict = defaultdict(list)

for step in tqdm(range(n_steps)):
    loss = svi.step(temperature, sales,)
    update_param_dict(param_dict)
    losses.append(loss)

for name in pyro.get_param_store():
    print("{}: {}".format(name, pyro.param(name)))
    intercept_q: 21.170936584472656
    beta_q: 2.459129810333252
    sigma2_q: 16.684457778930664

切片約21, 傾き約2.5, 分散約17ということで,妥当な推定結果が出ました.下図から,損失や各更新対象パラメタも収束していることがわかります.

download-1.pngdownload.png

GPUによる高速化

変分推論によっても事後分布の推定が行えることがわかったところで,Pyroの最大の長所であるGPU利用によるスケーラビリティについても考えていきたいと思います.
PyroでGPUを利用する最も簡単な方法は,コードの冒頭に
torch.set_default_tensor_type(torch.cuda.FloatTensor)
と書くことで,生成されるTensorを初めから全てGPUに乗せておくことです.
さて,例題ではサンプル数が100でしたが,これが1000, 10000, ...と増えていったら変分推論にかかる時間はどのように増えるでしょうか.
ここでは100個のサンプルから推定された切片,傾き,分散が仮に真の値であったとして,その真の分布(仮)から1000個, 10000個, ...のサンプルを生成し,そのサンプルを観測データとして変分推論を行った時にかかる時間をCPUとGPUで比較しました.

サンプル数 CPU(秒) GPU(秒)
$10^2$ 1.34 3.66
$10^3$ 1.4 1.95
$10^4$ 1.51 1.95
$10^5$ 4.2 1.96
$10^6$ 30.41 2.26
$10^7$ 365.36 11.41
$10^8$ 3552.44 104.62

表から分かる通り,サンプル数が1万以下の領域では並列化の恩恵が受けられず,CPUの方が速くなっています.しかし,サンプル数が10万を超えたあたりから差が開き始め,$10^8$個では30倍以上の開きがあるという結果になりました.GPUを使えば非常に大きなデータを扱う時にもストレスなく統計モデリングが行えるということが伺えます.

まとめ

今回はPyroを用いて単回帰モデルによる分析を行う方法を紹介しました.
仮定した統計モデルがデータを生成するプロセスをmodelメソッドに記述し,MCMCまたは変分推論によって事後分布を推定できます.
また,GPUを利用した高速化に関する実験を行い,Pyroによってスケーラブルな分析が行えることを確認しました.

11
6
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
6