55
45

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Stan Advent Calendar 2016

Day 23

【PyStan】Graphical LassoをStanでやってみる。

Last updated at Posted at 2016-12-23

こんにちは、久しぶりにブログを書く@kenmatsu4です。
Stan Advent Calendarの23日目の記事を書きました。

今回のブログでは、Graphical Lassoという、L1正則化をかけた精度行列(分散共分散行列の逆行列)を推定する手法をStanを用いてやってみようというものです。コードの全文はGitHubにアップロードしています。

1. テスト用データの生成

まず、多変量正規分布に従う乱数を生成します。
今回は下記のような平均、分散をもつ6次元のデータを300個生成します。
そして無理やり$x_6$と$x_4$、さらに$x_6$と$x_5$に相関を持たせ、$x_4$と$x_5$が間接相関を持つようにします。これはもともと$x_4$と$x_5$がなかったものの、$x_6$の影響を受けて$x_6$の変動と連動して$x_4$と$x_5$の値も動くので本来相関がない変数同士が相関を持っているようにみえる状態になります。←(*)の箇所

python
    m = np.zeros(M)
    # 多変量正規分布の乱数生成用の共分散行列
    cov = [[   1, .29, .49, .10, .30,  0],
           [ .29,   1,-.49, .49, .40,  0],
           [ .49,-.49,   1,   0,   0,  0],
           [ .10, .49,   0,   1,   0,  0],
           [ .30, .40,   0,   0,   1,  0],
           [   0,   0,   0,   0,   0,  1]]

    # 6つの変数を作成
    X = st.multivariate_normal.rvs(mean=m, cov=cov, size=size, random_state=random_state)

    # x4とx6に相関を持たせる(*)
    X[:,3] += 0.6*X[:,5]
    
    # x5とx6に相関を持たせる(*)
    X[:,4] += 0.6*X[:,5]

図:間接相関のイメージ
スクリーンショット 2016-12-23 20.43.46.png

得られたデータをpairplotで散布図を書くと下記の通りです。

スクリーンショット 2016-12-23 20.48.58.png

相関行列と、偏相関行列を算出すると下記の通りです。
偏相関とは、先ほどの図にあったような間接相関の効果を除き、直接的な相関がどの程度かを表す値として利用できます。早速、偏相関がどうなっているか見てみたいと思います。

偏相関行列は精度行列(分散共分散行列の逆行列)$\Lambda$の要素$\lambda_{ij}$を使って表すと

\hat{\rho}_{ij} = {-\lambda_{ij} \over \sqrt{\lambda_{ii}\lambda_{jj}}}

となります。

partial_corr_est.png

右が偏相関行列ですが、データから素直に算出した分散共分散行列を使って計算するとノイズの影響を受けてしまい、よく分からない値となっています。あまりデータの構造が見えているようではありません。全部の変数が関係し合っているようにも見えてしまっていますね・・・。そんなはずはないんです。

そこで、ノイズの影響をL1正則化、つまりLassoを用いて除外し、推定した分散共分散行列、精度行列を使ってみたいと思います。

2. Graphical Lasso

多変量正規分布を想定したマルコフグラフをガウス型グラフィカルモデルと呼び、この分布のパラメーターである精度行列$\Lambda$を用いて変数の関係性をグラフィカルモデルとしてみることができます。つまり精度行列のi,j要素$\lambda_{ij}$が0でない場合、$x_i$と$x_j$との間に直接相関があります。このような状態のものを下記のようなグラフで表現し、グラフィカルモデルと呼びます。

スクリーンショット 2016-12-23 21.26.05.png

データ$\boldsymbol{x}$が多変量正規分布に従うとすると、その分布は

\mathcal{N}(\boldsymbol{x} | {\bf 0}, \Lambda^{-1}) \sim {|\Lambda|^{1/2} \over (2\pi)^{M/2}} \exp\left( -{1 \over 2} \boldsymbol{x}^\intercal \Lambda \boldsymbol{x} \right)

と表せます。ただしこのままだと0以外の値が入っていると直接相関があるとみなすことにしたので、大体の推定結果はノイズが乗ることも考えると先ほどのグラフ構造はほとんどの変数間に直接相関があることになってしまいます。なるべくスパースな精度行列を求められるよう工夫が必要です。
精度行列$\Lambda$にラプラス分布$p(\Lambda)$の事前分布を想定することでスパースな解が求まります。事後分布は

p(\Lambda|\boldsymbol{x}) \propto p(\Lambda)\prod_{n=1}^{N}\mathcal{N}(\boldsymbol{x}^{(n)}|{\bf 0}, \Lambda)

となるため、logを取って$\Lambda$で微分することで

\ln |\Lambda| - \mathrm{tr}(S\Lambda)-\alpha\|\Lambda\|_1

となり、ラプラス分布の事前分布を設定することは、L1正則化をかけていることになります。

3. Scikit-Learnを用いてGraphical Lassoの解を求める

Scikit-LearnにはこのGraphical Lassoを実装したGraphLassoが実装されています。これには座標降下法という最適化手法が用いられています。まずはこれを試してみましょう。

実装は非常に簡単です。いつものとおりfitするだけです:relaxed: すると分散共分散行列と精度行列を取得することができます。

python
alpha = 0.2 # L1正則化パラメーター
model = GraphLasso(alpha=alpha,
                     max_iter=100,                     
                     verbose=True,
                     assume_centered = True)

model.fit(X)
cov_ = model.covariance_ # 分散共分散行列
prec_ = model.precision_ # 精度行列

得られた分散共分散行列と精度行列は下記の通りです。
glasso_cov_prec.png

また、そこから計算した相関行列と、偏相関行列が下記です。
無理やり作り出した$x_4, x_5, x_6$の相関のうち、$x_4$と$x_5$の相関が偏相関行列上で0になっていることがわかります。
corr_pcorr_sklearn.png

さて、これでスパースな精度行列から偏相関行列を求めることができて、めでたしめでたしなのですが、本ブログはStan Advent Calendar 2016の記事です。そして、先ほど、このL1正則化は$\Lambda$の事前分布にラプラス分布を想定したものと解釈できることを書きました。そしたら、これをStanでやってみることができそうですね。

4. Stanを用いてGraphical Lassoの解を求める

ということで、stanコードを書いてこのGraphical Lassoと同じことをやってみます。stanではラプラス分布はdouble exponential分布と呼ばれているので、これを使います。stanコードは下記になります。

glasso.stan
data {
  int N;           // Sample size
  int P;           // Feature size
  matrix[N, P] X;  // Data
  real alpha;      // scale parameter of double exponential (L1 parameter)
}
parameters {
  corr_matrix[P] Lambda; // Covariance matrix
}
model {
  vector[P] zeros;
  for (i in 1:P) {
     zeros[i] = 0;
  }
  
  // Precision matrix follows laplace distribution
  to_vector(Lambda) ~ double_exponential(0, 1/alpha);
  
  for (j in 1:N){
    // X follows multi normal distribution
    X[j] ~ multi_normal(zeros, inverse(Lambda));
  }
}
generated quantities {
  matrix[P, P] Sigma;
  Sigma = inverse(Lambda);
}

これを呼び出すPythonコードはこちらです。

python
%time sm = pystan.StanModel(file='glasso.stan')
print('Compile finished.')

n_sample = 1000  # 1 chainあたりのサンプル数
n_warm   = 1000  # warm upに使う数
n_chain  = 4     # chain数
stan_data = {'N': X.shape[0], 'P': P, 'X': X, 'alpha': alpha}
%time fit = sm.sampling(data=stan_data, chains=n_chain, iter=n_sample+n_warm, warmup=n_warm)
print('Sampling finished.')

fit

結果がこちらです。分散共分散行列の対角要素のRhatがnanになってしまっているので要注意ですが、値はおかしくなさそうで、他の要素のRhatは全て1.0になっています。

out
Inference for Stan model: anon_model_31ac7e216f1b5eccff16f1394bd9827e.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

              mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
Lambda[0,0]    1.0     0.0    0.0    1.0    1.0    1.0    1.0    1.0   4000    nan
Lambda[1,0]  -0.31  8.7e-4   0.05   -0.4  -0.34  -0.31  -0.28  -0.21   2891    1.0
Lambda[2,0]  -0.42  7.1e-4   0.04   -0.5  -0.45  -0.42  -0.39  -0.33   3735    1.0
Lambda[3,0]   0.04  1.0e-3   0.05  -0.07 2.1e-3   0.04   0.07   0.14   2808    1.0
Lambda[4,0]  -0.12  9.1e-4   0.05  -0.22  -0.15  -0.11  -0.08-9.5e-3   3437    1.0
Lambda[5,0]   0.02  1.0e-3   0.06  -0.09  -0.01   0.02   0.06   0.13   3014    1.0
Lambda[0,1]  -0.31  8.7e-4   0.05   -0.4  -0.34  -0.31  -0.28  -0.21   2891    1.0
Lambda[1,1]    1.0 1.5e-189.0e-17    1.0    1.0    1.0    1.0    1.0   3633    nan
Lambda[2,1]   0.47  6.3e-4   0.04   0.39   0.44   0.47    0.5   0.55   4000    1.0
Lambda[3,1]  -0.31  7.6e-4   0.05   -0.4  -0.34  -0.31  -0.28  -0.22   3810    1.0
Lambda[4,1]  -0.19  9.4e-4   0.05  -0.29  -0.22  -0.19  -0.15  -0.08   3021    1.0
Lambda[5,1]    0.2  8.8e-4   0.05    0.1   0.16    0.2   0.23    0.3   3395    1.0
Lambda[0,2]  -0.42  7.1e-4   0.04   -0.5  -0.45  -0.42  -0.39  -0.33   3735    1.0
Lambda[1,2]   0.47  6.3e-4   0.04   0.39   0.44   0.47    0.5   0.55   4000    1.0
Lambda[2,2]    1.0 3.6e-188.7e-17    1.0    1.0    1.0    1.0    1.0    594    nan
Lambda[3,2]  -0.11  8.9e-4   0.05  -0.22  -0.15  -0.11  -0.08  -0.01   3623    1.0
Lambda[4,2]  -0.04  9.1e-4   0.05  -0.15  -0.08  -0.04-5.8e-3   0.07   3642    1.0
Lambda[5,2]   0.03  9.0e-4   0.05  -0.08-9.2e-3   0.03   0.06   0.13   3495    1.0
Lambda[0,3]   0.04  1.0e-3   0.05  -0.07 2.1e-3   0.04   0.07   0.14   2808    1.0
Lambda[1,3]  -0.31  7.6e-4   0.05   -0.4  -0.34  -0.31  -0.28  -0.22   3810    1.0
Lambda[2,3]  -0.11  8.9e-4   0.05  -0.22  -0.15  -0.11  -0.08  -0.01   3623    1.0
Lambda[3,3]    1.0 2.0e-181.2e-16    1.0    1.0    1.0    1.0    1.0   3553    nan
Lambda[4,3]  -0.02  9.3e-4   0.06  -0.13  -0.06  -0.02   0.02   0.09   3503    1.0
Lambda[5,3]  -0.38  7.5e-4   0.04  -0.47  -0.41  -0.38  -0.35  -0.29   3591    1.0
Lambda[0,4]  -0.12  9.1e-4   0.05  -0.22  -0.15  -0.11  -0.08-9.5e-3   3437    1.0
Lambda[1,4]  -0.19  9.4e-4   0.05  -0.29  -0.22  -0.19  -0.15  -0.08   3021    1.0
Lambda[2,4]  -0.04  9.1e-4   0.05  -0.15  -0.08  -0.04-5.8e-3   0.07   3642    1.0
Lambda[3,4]  -0.02  9.3e-4   0.06  -0.13  -0.06  -0.02   0.02   0.09   3503    1.0
Lambda[4,4]    1.0 2.0e-181.2e-16    1.0    1.0    1.0    1.0    1.0   3633    nan
Lambda[5,4]  -0.36  7.2e-4   0.05  -0.45  -0.39  -0.36  -0.33  -0.27   4000    1.0
Lambda[0,5]   0.02  1.0e-3   0.06  -0.09  -0.01   0.02   0.06   0.13   3014    1.0
Lambda[1,5]    0.2  8.8e-4   0.05    0.1   0.16    0.2   0.23    0.3   3395    1.0
Lambda[2,5]   0.03  9.0e-4   0.05  -0.08-9.2e-3   0.03   0.06   0.13   3495    1.0
Lambda[3,5]  -0.38  7.5e-4   0.04  -0.47  -0.41  -0.38  -0.35  -0.29   3591    1.0
Lambda[4,5]  -0.36  7.2e-4   0.05  -0.45  -0.39  -0.36  -0.33  -0.27   4000    1.0
Lambda[5,5]    1.0 2.2e-181.3e-16    1.0    1.0    1.0    1.0    1.0   3381    nan
Sigma[0,0]    1.31  1.1e-3   0.07   1.19   1.26    1.3   1.35   1.45   3507    1.0
Sigma[1,0]    0.26  1.3e-3   0.08   0.11   0.21   0.27   0.32   0.43   4000    1.0
Sigma[2,0]    0.45  1.3e-3   0.08   0.29   0.39   0.44   0.51   0.62   4000    1.0
Sigma[3,0]     0.1  1.2e-3   0.08  -0.05   0.05    0.1   0.15   0.25   4000    1.0
Sigma[4,0]    0.23  1.2e-3   0.08   0.09   0.18   0.23   0.28   0.38   4000    1.0
Sigma[5,0]    0.03  1.2e-3   0.08  -0.13  -0.02   0.03   0.08   0.18   4000    1.0
Sigma[0,1]    0.26  1.3e-3   0.08   0.11   0.21   0.27   0.32   0.43   4000    1.0
Sigma[1,1]    1.55  1.5e-3   0.09   1.38   1.48   1.54   1.61   1.74   4000    1.0
Sigma[2,1]   -0.56  1.4e-3   0.09  -0.74  -0.62  -0.56  -0.49  -0.39   4000    1.0
Sigma[3,1]    0.41  1.3e-3   0.08   0.24   0.35    0.4   0.46   0.57   4000    1.0
Sigma[4,1]    0.29  1.3e-3   0.08   0.14   0.24   0.29   0.34   0.46   4000    1.0
Sigma[5,1]   -0.04  1.3e-3   0.08   -0.2  -0.09  -0.04   0.02   0.13   4000    1.0
Sigma[0,2]    0.45  1.3e-3   0.08   0.29   0.39   0.44   0.51   0.62   4000    1.0
Sigma[1,2]   -0.56  1.4e-3   0.09  -0.74  -0.62  -0.56  -0.49  -0.39   4000    1.0
Sigma[2,2]    1.47  1.3e-3   0.08   1.32   1.41   1.46   1.52   1.65   4000    1.0
Sigma[3,2]  2.9e-3  1.3e-3   0.08  -0.15  -0.05 1.3e-3   0.06   0.16   4000    1.0
Sigma[4,2]    0.04  1.2e-3   0.08  -0.12  -0.02   0.03   0.09   0.19   4000    1.0
Sigma[5,2]    0.07  1.3e-3   0.08  -0.08   0.02   0.08   0.13   0.23   4000    1.0
Sigma[0,3]     0.1  1.2e-3   0.08  -0.05   0.05    0.1   0.15   0.25   4000    1.0
Sigma[1,3]    0.41  1.3e-3   0.08   0.24   0.35    0.4   0.46   0.57   4000    1.0
Sigma[2,3]  2.9e-3  1.3e-3   0.08  -0.15  -0.05 1.3e-3   0.06   0.16   4000    1.0
Sigma[3,3]    1.36  1.1e-3   0.07   1.23    1.3   1.35    1.4   1.51   4000    1.0
Sigma[4,3]    0.31  1.2e-3   0.08   0.17   0.26   0.31   0.36   0.47   4000    1.0
Sigma[5,3]    0.55  1.4e-3   0.09   0.39   0.49   0.55    0.6   0.73   4000    1.0
Sigma[0,4]    0.23  1.2e-3   0.08   0.09   0.18   0.23   0.28   0.38   4000    1.0
Sigma[1,4]    0.29  1.3e-3   0.08   0.14   0.24   0.29   0.34   0.46   4000    1.0
Sigma[2,4]    0.04  1.2e-3   0.08  -0.12  -0.02   0.03   0.09   0.19   4000    1.0
Sigma[3,4]    0.31  1.2e-3   0.08   0.17   0.26   0.31   0.36   0.47   4000    1.0
Sigma[4,4]    1.29  9.9e-4   0.06   1.19   1.25   1.29   1.33   1.43   4000    1.0
Sigma[5,4]    0.53  1.3e-3   0.08   0.38   0.47   0.52   0.58    0.7   4000    1.0
Sigma[0,5]    0.03  1.2e-3   0.08  -0.13  -0.02   0.03   0.08   0.18   4000    1.0
Sigma[1,5]   -0.04  1.3e-3   0.08   -0.2  -0.09  -0.04   0.02   0.13   4000    1.0
Sigma[2,5]    0.07  1.3e-3   0.08  -0.08   0.02   0.08   0.13   0.23   4000    1.0
Sigma[3,5]    0.55  1.4e-3   0.09   0.39   0.49   0.55    0.6   0.73   4000    1.0
Sigma[4,5]    0.53  1.3e-3   0.08   0.38   0.47   0.52   0.58    0.7   4000    1.0
Sigma[5,5]    1.42  1.3e-3   0.08   1.28   1.36   1.41   1.47   1.59   4000    1.0
lp__        -713.2    0.06   2.67 -719.3 -714.9 -712.9 -711.2 -709.0   1983    1.0

Samples were drawn using NUTS at Sat Dec 24 00:05:39 2016.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

これだけだとちょっと見づらいので、グラフを書いてみましょう。

python
# 推定したパラメーターの取り出し
Lambda = fit.extract()["Lambda"]
Sigma  = fit.extract()["Sigma"]

# EAP推定量の算出
EAP_Sigma  = np.mean(Sigma, axis=0)
EAP_Lambda = np.mean(Lambda, axis=0)

# EAP推定量の可視化
plt.figure(figsize=(10,4))
ax = plt.subplot(121)
sns.heatmap(pd.DataFrame(EAP_Sigma), annot=EAP_Sigma, fmt='0.2f', ax=ax, xticklabels=label, yticklabels=label)
plt.title("Graphical Lasso with Stan: Covariance matrix")

ax = plt.subplot(122)
sns.heatmap(pd.DataFrame(EAP_Lambda), annot=EAP_Lambda, fmt='0.2f', ax=ax, xticklabels=label, yticklabels=label)
plt.title("Graphical Lasso with Stan: Precision matrix")
plt.savefig(img_path+"glasso_stan_cov_prec.png", dpi=128)
plt.show()

glasso_stan_cov_prec.png

python
# 相関行列の算出
EAP_cor = np.empty_like(EAP_Sigma)
for i in range(P):
    for j in range(P):
        EAP_cor[i, j] = EAP_Sigma[i, j]/np.sqrt(EAP_Sigma[i, i]*EAP_Sigma[j, j])
        
# 偏相関行列の算出
EAP_rho = np.empty_like(EAP_Lambda)
for i in range(P):
    for j in range(P):
        EAP_rho[i, j] = -EAP_Lambda[i, j]/np.sqrt(EAP_Lambda[i, i]*EAP_Lambda[j, j])
        
plt.figure(figsize=(11,4))
ax = plt.subplot(122)
sns.heatmap(pd.DataFrame(EAP_rho), annot=EAP_rho, fmt='0.2f', ax=ax, xticklabels=label, yticklabels=label)
plt.title("Partial correlation Coefficiant with stan")
#plt.savefig(img_path+"partial_corr_sklearn.png", dpi=128)

ax = plt.subplot(121)
sns.heatmap(pd.DataFrame(EAP_cor), annot=EAP_cor, fmt='0.2f', ax=ax, xticklabels=label, yticklabels=label)
plt.title("Correlation Coefficiant with stan")
plt.savefig(img_path+"corr_pcorr_stan.png", dpi=128)
plt.show()

当たり前ですが、乱数シミュレーションのため、完全に0になる要素はありません。そういった意味ではあまりL1正則化の効力はStanシミュレーションでは得られないということになりますね。
corr_pcorr_stan.png

こちらは先ほどのScikit-Learnのもの。比較対象用。

corr_pcorr_sklearn.png

ちょっと値が違うのですが、似た構造になっています。
$x_4$と$x_5$の間接相関も偏相関行列ではちゃんと消えています。
サンプリングしたヒストグラムとScikit-Learnの結果を重ねたものを下記に描画しました。

サンプリング結果の可視化

サンプリング結果のヒストグラム
grid_dist_plot1.png

拡大バージョン。
赤い線がScikit-Learnの結果。点線が事後分布の2.5%点と 97.5%点。幾つかは外れてしまっていますが、そこそこの割合で区間内に入っている様子が見て取れます。なので、だいたい同じ結果(対角要素を除く)と言えるのではないでしょうか。全てが確信区間に入っているわけではないので、もう少しチューニングが必要かもしれません。
grid_dist_plot2.png

Trace Plot
trace_plot.png

5. おわりに

Graphical LassoのL1正則化が、パラメーターの事前分布をLaplace分布にしたものだ、ということを知りそれをStanで試してみたくて書いた記事でした。近しい構造の偏相関行列が得られましたが、Scikit-Learnの結果とズレが少しあるので、もう少し調べてみたいと思います。

参考

「異常検知と変化検知(機械学習プロフェッショナルシリーズ)」井手剛、杉山将
Stan Modeling Language User’s Guide and Reference Manual
 ⇒ http://www.uvm.edu/~bbeckage/Teaching/DataAnalysis/Manuals/stan-reference-2.8.0.pdf
偏相関係数(Partial Correration Coefficient)  
 ⇒ http://www.ae.keio.ac.jp/lab/soc/takeuchi/lectures/5_Parcor.pdf

55
45
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
55
45

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?