LoginSignup
3
3

概要

WAICはAICを改良して様々な場合でも利用できるようにした情報量基準です。詳しい理論的な背景は考案者である渡辺澄夫氏の以下のWebページなどを参照してください。(2024年には公開が終わるようなので早めに保存したほうがいいかも)
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/waic2011.html

この記事ではJaxをベースに開発された確率プログラミングのフレームワークであるNumPyroを用いてこのWAICを求める方法について説明します。(WAICの理論やNumPyro自体の詳細な解説はしません)

コードはこちらで公開しています。
https://github.com/lucidfrontier45/numpyro_linear_regression_waic

環境

  • Python 3.11
  • numpyroとjaxtypingをpipなどでインストール

モデル定義とパラメータ推定

ここでは以下のような単純な線形モデルを例にとります。

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jaxtyping import Array, Float32


def linear_model(X: Float32[Array, "N D"], y: Float32[Array, " N"] | None):
    N, D = X.shape

    with numpyro.plate("dimension", D):
        w = numpyro.sample("w", dist.Normal(0, 1))

    sigma = numpyro.sample("sigma", dist.HalfCauchy(1))

    z = jnp.dot(X, w)  # type: ignore

    with numpyro.plate("data", N):
        y = numpyro.sample("y", dist.Normal(z, sigma), obs=y)  # type: ignore

Hamiltonモンテカルロ法を用いたパラメータのサンプリングは以下のように行います。

import jax
import numpy as np
from jaxtyping import Array, Float, Float32
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import log_likelihood
from scipy.special import logsumexp


def run_mcmc(
    model,
    X: Float32[Array, "N D"],
    y: Float32[Array, " N"],
    num_warmup: int = 1000,
    num_samples: int = 1000,
    num_chains: int = 1,
    seed: int = 0,
):
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=False,
    )
    rng_key = jax.random.PRNGKey(seed)
    mcmc.run(rng_key, X, y)
    return mcmc

この関数を実行することでNumPyroのMCMC構造体が得られ、パラメータの事後分布からのサンプルが得られます。

WAICの計算

いよいよWAICの計算です。上記の渡辺氏の資料によるとWAICは以下のように計算することができます。

\begin{align}
\mathrm{WAIC} &= T + V \\
T &= -\frac{1}{N}\sum_n^N \log \mathrm{E}_w\left[P(Z_n|w)\right] \\
V &= \frac{1}{N}\sum_n^N \left(\mathrm{E}_w\left[\log{P(Z_n|w)}^2\right] - \mathrm{E}_w\left[\log{P(Z_n|w)}\right]^2 \right) \\
Z_n &= \left\{X_n, y_n \right\}
\end{align}

この計算をするためには事後分布のパラメータのサンプルそれぞれについて尤度$P(Z_n|w)$あるいは大数尤度$\log P(Z_n|w)$を求める必要がありますが、NumPyroではnumpyro.infer.util.log_likelihoodを使用することでこれを計算することができます。

logp = log_likelihood(model, posterior_samples, X, y)["y"]

これを使用して$T,V$を計算していきます。基本的には確率分布に対する期待値をサンプル平均で置き換えます。

まず$T$ですが、式を変形すると以下のようになります。

\begin{align}
T = -\frac{1}{N}\sum_n^N \log \frac{1}{M}\sum_m^M \exp \log P(Z_n|w_m) \\
\end{align}

$1/M$が途中にありますが、基本的にはいわゆるlogsumexpと呼ばれる計算です。大数尤度のlogsumexpに定数項$-\log M$を足すという方式でもいいですが、SciPyのlogsumexp関数はスケーリング項bを受け付けるのでこれを使用すればいいです。

$V$についてはnについての和の中身が二乗の平均引く平均の二乗で分散の定義そのままですのでパラメータのサンプル軸$m$方向にで$\log P(Z_n|w_m)$の分散を計算し、データのサンプル軸$n$方向に平均を計算すればいいです。

\begin{align}
V = \frac{1}{N}\sum_n^N V_{w}\left[ \log P(Z_n|w_m) \right]
\end{align}

まとめると以下のようになります。

def calc_waic(logp: Float[np.ndarray, "M N"]) -> float:
    M = logp.shape[0]  # number of posterior samples
    T = -logsumexp(logp, axis=0, b=1.0 / M).mean()
    V = logp.var(axis=0).mean()
    return T + V


def evaluate_model(
    model,
    X: Float32[Array, "N D"],
    y: Float32[Array, " N"],
    posterior_samples: dict[str, Float32[Array, "M _*"]],
):
    logp = log_likelihood(model, posterior_samples, X, y)["y"]
    return calc_waic(jax.device_get(logp))

実験

# テストデータ準備
# 線形結合の係数wはあえて無駄な次元を2つ付け、モデル選択の検証に用いる 
w = np.array([3.5, -1.5,  0.0, 0.0])
sigma = 0.5

D = len(w)
N = 100
np.random.seed(0)
X_ = np.random.randn(N, D)
y_ = np.dot(X_, w) + np.random.randn(N) * sigma

X = jax.device_put(X_)
y = jax.device_put(y_)

# パラメータ推定とWAICの計算
mcmc = run_mcmc(linear_model, X, y)
waic = evaluate_model(linear_model, X, y, mcmc.get_samples())
print(waic)
> 0.7954699710394434

# arvizの実装と比較
# どうやらデータ数Nで割り算されていないようであるが、その分を除けば一致している
import arviz
arviz.waic(mcmc, scale="negative_log")
>            Estimate       SE
> -elpd_waic    79.55     7.12
> p_waic         4.73        -

# モデル選択
# 後ろ2つのダミー次元を除いた2次元が最適であると正しく求まった
for i in range(4):
    XX = X[:, :D-i]
    mcmc = run_mcmc(linear_model, XX, y)
    waic = evaluate_model(linear_model, XX, y, mcmc.get_samples())
    print(f"WAIC for {D-i} dimensions: {waic}")
> WAIC for 4 dimensions: 0.7954699710394434
> WAIC for 3 dimensions: 0.7834370291883676
> WAIC for 2 dimensions: 0.7765665572323642
> WAIC for 1 dimensions: 1.9153631023793407

まとめ

WAICはパラメータの事後分布からの各サンプルに対する対数尤度$\log p$が求まればcalc_waicのように非常に簡単に求めることができると分かりました。機械学習のモデルの訓練をベイズ推定で行う場合、ハイパーパラメータの最適化時は交差検証を使用せずにWAICを利用することで大幅に計算時間を削減できそうです。

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3