LoginSignup
12
11

More than 5 years have passed since last update.

PRML§3.3.1 ベイズ線形回帰によるパラメータ分布の収束図を再現する

Last updated at Posted at 2015-06-10

社内勉強会でぼちぼち進めているPRML、カラーの図版があると真面目に再現したくなります。今回は3.3.1のパラメータの分布の事後確率が収束する様子の図版を再現しました。この図は

  1. 直近の観測データ1つによるパラメータの尤度
  2. パラメータの事前/事後確率分布
  3. 事前/事後確率分布から得たパラメータによるあてはめ結果

の3つからなり、事後確率が収束していく様子が見られて面白いです。

前提

  • 1次元の入力変数$x$と、1次元の目標変数$t$を考える
  • $y(x, w) = w_0 + w_1x$ の線形モデルにあてはめを行なう
  • モデルのパラメータ2つの分布は、2変量ガウス分布とする
  • ノイズはガウスノイズとする

必要な計算部品

観測データから計画行列$\Phi$の作成

def design_matrix(x):
    return np.array([[1, xi] for xi in x])

N回データを観測した後の平均 $m_N = \beta S_N\Phi^{T}t$ を求める関数

def calc_mn(alpha, beta, x, t):
    Phi = design_matrix(x)
    Sn = calc_Sn(alpha, beta, x)
    return beta * Sn.dot(Phi.T).dot(t)

N回データを観測した後の共分散 $S_N = (\alpha I + \beta\Phi^{T}\Phi)^{-1}$ を求める関数

I = np.identity(2)

def calc_Sn(alpha, beta, x):
    Phi = design_matrix(x)
    return np.linalg.inv(alpha*I + beta*Phi.T.dot(Phi))

観測データの尤度

観測データは精度$\beta$のガウスノイズが乗る事から、対数尤度関数

logL(w) = -\frac{\beta}{2}(t-w^{T}\phi(x))^2 + cons

を利用する。対数尤度関数でプロットしたい範囲の$w$を求める。

def calc_likelifood(beta, t, x, w):
    """
    観測値1つの対数尤度を求める
    """
    w = np.array(w)
    phi_x = np.array([1, x])
    return -1 * beta / 2 * (t - w.T.dot(phi_x))**2

def plot_likelifood(beta, t, x, title='', ax=None):
    """
    観測値の尤度のプロット
    """
    w0 = np.linspace(-1, 1, 100)
    w1 = np.linspace(-1, 1, 100)
    W0,W1 = np.meshgrid(w0, w1)
    L = []
    for w0i in w0:
        L.append([calc_likelifood(beta, t, x, [w0i, w1i]) for w1i in w1])

    ax.pcolor(W0, W1, np.array(L).T, cmap=plt.cm.jet, vmax=0, vmin=-1)
    ax.set_xlabel('$w_0$')
    ax.set_ylabel('$w_1$')
    ax.set_title(title)

事後確率

上記の式で求めた $m_N$と$S_N$をパラメータとした2変量正規分布が$w$の事後確率分布となる、$m_0$と$S_0$の場合は事前確率分布。

def plot_probability(mean, cov, title='', ax=None):
    """
    確率分布(2変量ガウス)のプロット
    """
    w0 = np.linspace(-1, 1, 100)
    w1 = np.linspace(-1, 1, 100)
    W0,W1 = np.meshgrid(w0, w1)
    P = []
    for w0i in w0:
        P.append([scipy.stats.multivariate_normal.pdf([w0i,w1i], mean, cov) for w1i in w1])

    ax.pcolor(W0, W1, np.array(P).T, cmap=plt.cm.jet)
    ax.set_xlabel('$w_0$')
    ax.set_ylabel('$w_1$')
    ax.set_title(title)

データ空間

上記の式で求めた $m_N$と$S_N$をパラメータとした2変量正規分布から$w$を6回サンプリングして、プロットする。

結果

データサンプル前
index.png

1回データサンプル後
index2.png

8回データサンプル後
index3.png

できました。全てのコードはgithubにもアップしてあります。
https://github.com/hagino3000/public-ipynb/blob/master/PRML/PRML%203.3.ipynb

12
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
11