LoginSignup
16
15

More than 1 year has passed since last update.

ロジスティック回帰はベイズ推定すると何が変わるのかを可視化してみる

Last updated at Posted at 2022-02-02

ベイズ推定と予測

ベイズ推定ではモデルのパラメータ$\theta$も確率変数であるとみなし、データ$X$を観測したときにこれの事後分布$P(\theta|X)$を推定したのちに新規データ$X'$に対して以下のように予測を行います。

P(X'|X) = \int d\theta P(X'|\theta)P(\theta|X)

非ベイズな手法では単純に点推定したパラメータ $\theta^*$ に対して $P(X'|\theta^*)$ とするのに対し、ベイズでは上記のようにパラメータの事後分布について平均をとってパラメータに依存しない式にしているのが特徴です。ロジスティック回帰においてこれが果たしてどのような効果を持つのかが気になったので可視化してみます。

通常のロジスティック回帰

まずは単純な点推定版を試してみます。以下の方法でテストデータを作成します。

%matplotlib inline
import matplotlib.pyplot as plt

import jax.numpy as jnp
from jax import random
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(
    n_features=2,
    n_informative=2,
    n_clusters_per_class=1,
    n_redundant=0,
    n_repeated=0,
    class_sep=0.5,
    random_state=1,
)

X = jnp.array(X)
y = jnp.array(y)

plt.figure(figsize=(8, 8))
mask0 = y==0
plt.scatter(X[mask0,0], X[mask0, 1], label="0")
mask1 = y==1
plt.scatter(X[mask1,0], X[mask1, 1], label="1")
plt.legend()

d52d3a37-1976-4f68-859c-557f3b7ea23f.png

このデータに対してロジスティック回帰で予測モデルを作成し、グリッド状の各座標においてクラス1になる確率を求めて可視化してみます1

# 予測モデル作成/学習
lr_model = LogisticRegression().fit(X, y)

# グリッドデータ作成 
xx = jnp.linspace(-3, 3, 200)
yy = jnp.linspace(-3, 3, 200)
xx, yy = jnp.meshgrid(xx, yy)
Xfull = jnp.c_[xx.ravel(), yy.ravel()]

# 各点に対して確率を求める
p_pred = lr_model.predict_proba(Xfull)[:, 1]

# 可視化
plt.figure(figsize=(8, 8))
plt.imshow(p_pred .reshape(200, 200), extent=(-3, 3, -3, 3), origin="lower", cmap="Greys")
mask0 = y==0
plt.scatter(X[mask0,0], X[mask0, 1])
mask1 = y==1
plt.scatter(X[mask1,0], X[mask1, 1])
for p in (0.1, 0.3, 0.5, 0.7, 0.9):
    mask = (p_pred > p - 0.005) & (p_pred < p + 0.005)
    plt.plot(Xfull[mask, 0], Xfull[mask, 1], "-", lw=0.5, label=f"{p=}")
plt.legend()

879dba29-7a21-4144-be69-d18d73b35942.png

上記の図では黒いほどp=1に近いことを表します。また確率pが同じ値の線も描いています。いずれも直線になります。

ベイジアンロジスティック回帰

お待ちかねのベイズ推定です。NumPyroを使用してHamilton Monte Carloでサンプリングします。

import numpyro as pyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

def model(X, y=None):
    w = pyro.sample("w", dist.MultivariateNormal(jnp.zeros(2), 5.0 * jnp.eye(2)))
    b = pyro.sample("b", dist.Normal(0.0, 5.0))
    with pyro.plate("data", size=len(X)):
        z = pyro.deterministic("z", jnp.dot(X, w) + b)
        pyro.sample("y", dist.BernoulliLogits(z), obs=y)

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, X, y)
mcmc.print_summary()

>>>

                mean       std    median      5.0%     95.0%     n_eff     r_hat
         b     -0.20      0.76     -0.20     -1.33      1.14    479.34      1.00
      w[0]     -7.82      1.30     -7.75     -9.72     -5.45    487.29      1.01
      w[1]      0.85      1.09      0.87     -1.05      2.52    518.68      1.00

ちゃんと収束しているようです。予測は以下のように行います。

# サンプルから予測分布関数を作成する
posterior_f = Predictive(model, mcmc.get_samples())
# グリッドデータに対して線形予測子の予測分布のサンプルを得る
posterior_z = posterior_f(rng_key, Xfull)["z"]
# シグモイド関数で確率に変換し、サンプルの平均をとる
p_pred = (1.0 / (1.0 + jnp.exp(-posterior_z))).mean(0)

# 可視化
plt.figure(figsize=(8, 8))
plt.imshow(p_pred .reshape(200, 200), extent=(-3, 3, -3, 3), origin="lower", cmap="Greys")
mask0 = y==0
plt.scatter(X[mask0,0], X[mask0, 1])
mask1 = y==1
plt.scatter(X[mask1,0], X[mask1, 1])
for p in (0.1, 0.3, 0.5, 0.7, 0.9):
    mask = (p_pred > p - 0.005) & (p_pred < p + 0.005)
    plt.plot(Xfull[mask, 0], Xfull[mask, 1], "-", lw=0.5, label=f"{p=}")
plt.legend()

ce626631-4829-4347-b76a-354c6a70b550.png

p=0.5は点推定の時と同じく直線ですが、それ以外の場合は曲線になっているではありませんか。

何が起きているのか?

BishopのPRML2のChapter10にて説明されています。次の図10.13の右側のように複数の直線が事後分布から生成され、これらを出現確率で平均をとるとp=0.5以外のケースについては曲線になるようです。

image.png

まとめ

ロジスティック回帰のパラメータをベイズ推定し、予測分布を作成して利用するときにp=0.5では通常の点推定と全く同じ結果で意味がありません。しかし、p=0.9などの偏った値を利用して実際のクラスラベルを推定するときには非線形な決定境界ができ、ベイズ推定の真価を発揮します。ここでは単純なロジスティック回帰を扱いましたが、ニューラルネットワークによる識別モデルも最後の出力層はロジスティック回帰と同じですので同じ議論ができると思います。

16
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
15