5
3

More than 1 year has passed since last update.

tidyposteriorの紹介

Last updated at Posted at 2021-12-11

1.はじめに

 tidymodelsを使って作成したモデルを評価するために交差検証法が用いられますが、どのように分割するかで、交差検証法を実施するたびに微妙に結果が異なるため、不確実性を示す統計的な定量がほしくなります。
 このtidyposteriorを使えば、ベイズ統計の枠組みで評価することができます。
tidymodelsの他、分析に必要なパッケージは別途インストールしてください。(rstan,rstanarmがバックエンドで動きます)

2.tidyposterior

このパッケージは、モデルによって生成されたリサンプリング結果の事後解析を行うために用います。

例えば、2つのモデルを10回の交差検証法を用いてRMSEで評価すると10個のペア統計ができますが、それらを使って別途テストデータを使用せずにモデル間の比較を行うことができます。
本パッケージはBenavoli et al (2017)に従って作成しています。
tidyposteriorは、ベイズ一般化線形モデルを使用しており、caret::resamples()関数のアップグレード版とも言えます。

3.パッケージのインストールの方法

CRANからインストールできます。

install.packages("tidyposterior")

4.例題

シンプルな2クラス問題を例に10分割交差検証法を使ってモデルの性能を評価します。
評価指標は、roc_aucを用います。

library(tidyverse)
library(tidymodels)
library(tidyposterior)
library(discrim)

データをセットします。

data(two_class_dat, package = "modeldata")

set.seed(123)
folds <- vfold_cv(two_class_dat)

データの概要を見てみます。

ggplot(two_class_dat, aes(x=A,y=B)) + geom_point(aes(color=Class)) + theme_bw()

image.png

2つの異なったモデルを作成します。(単純にするために、パラメータのチューニングは行いません)

#ロジスティック回帰
logistic_reg_glm_spec <-
  logistic_reg() %>%
  set_engine('glm')

#Linear discriminant analysis(LDA) 線形判別分析
lda_spec <-
  discrim_linear() %>%
  set_engine("mda") 

※tidymodelsによるモデルの指定方法は以下を参照
https://www.tidymodels.org/find/parsnip/

tidumodelsでは、分割されたリサンプルにそれぞれfitさせるには、fit_resample() を用います。

rs_ctrl <- control_resamples(save_workflow = TRUE)

logistic_reg_glm_res <- 
  logistic_reg_glm_spec %>% 
  fit_resamples(Class ~ ., resamples = folds, control = rs_ctrl)

lda_res <- 
  lda_spec %>% 
  fit_resamples(Class ~ ., resamples = folds, control = rs_ctrl)

結果をperf_mod()関数に渡します。
collect_metrics()を用いて、評価します。

logistic_roc <- 
  collect_metrics(logistic_reg_glm_res, summarize = FALSE) %>% 
  dplyr::filter(.metric == "roc_auc") %>% 
  dplyr::select(id, logistic = .estimate)

lda_roc <- 
  collect_metrics(lda_res, summarize = FALSE) %>% 
  dplyr::filter(.metric == "roc_auc") %>% 
  dplyr::select(id, lda = .estimate)

resamples_df <- full_join(logistic_roc, lda_roc, by = "id")
resamples_df
>id     logistic   lda
>   <chr>     <dbl> <dbl>
> 1 Fold01    0.88  0.874
> 2 Fold02    0.894 0.893
> 3 Fold03    0.936 0.940
> 4 Fold04    0.861 0.862
> 5 Fold05    0.854 0.854
> 6 Fold06    0.901 0.891
> 7 Fold07    0.890 0.893
> 8 Fold08    0.867 0.865
> 9 Fold09    0.892 0.892
>10 Fold10    0.894 0.898

通常、このリサンプルされた10Foldsの平均でモデルの評価をするところですが、事後分布を解析するため、この結果を直接perf_mod()に渡します。

set.seed(123)
roc_model_via_df <- perf_mod(resamples_df, iter = 2000)
>SAMPLING FOR MODEL 'continuous' NOW (CHAIN 1).
>Chain 1: 
>Chain 1: Gradient evaluation took 0 seconds
>Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0 seconds.
>Chain 1: Adjust your expectations accordingly!
>Chain 1: 
>Chain 1: 
>Chain 1: Iteration:    1 / 2000 [  0%]  (Warmup)
>Chain 1: Iteration:  200 / 2000 [ 10%]  (Warmup)
>Chain 1: Iteration:  400 / 2000 [ 20%]  (Warmup)
>Chain 1: Iteration:  600 / 2000 [ 30%]  (Warmup)
>Chain 1: Iteration:  800 / 2000 [ 40%]  (Warmup)
>Chain 1: Iteration: 1000 / 2000 [ 50%]  (Warmup)
>Chain 1: Iteration: 1001 / 2000 [ 50%]  (Sampling)
>Chain 1: Iteration: 1200 / 2000 [ 60%]  (Sampling)
>Chain 1: Iteration: 1400 / 2000 [ 70%]  (Sampling)
>Chain 1: Iteration: 1600 / 2000 [ 80%]  (Sampling)
>Chain 1: Iteration: 1800 / 2000 [ 90%]  (Sampling)
>Chain 1: Iteration: 2000 / 2000 [100%]  (Sampling)
>Chain 1: 
>Chain 1:  Elapsed Time: 2.698 seconds (Warm-up)
>Chain 1:                1.125 seconds (Sampling)
>Chain 1:                3.823 seconds (Total)
>Chain 1: 
>
>SAMPLING FOR MODEL 'continuous' NOW (CHAIN 2).
>Chain 2: 
>Chain 2: Gradient evaluation took 0 seconds
>Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 0 seconds.
>Chain 2: Adjust your expectations accordingly!
>Chain 2: 
>Chain 2: 
>Chain 2: Iteration:    1 / 2000 [  0%]  (Warmup)
>Chain 2: Iteration:  200 / 2000 [ 10%]  (Warmup)
>Chain 2: Iteration:  400 / 2000 [ 20%]  (Warmup)
>Chain 2: Iteration:  600 / 2000 [ 30%]  (Warmup)
>Chain 2: Iteration:  800 / 2000 [ 40%]  (Warmup)
>Chain 2: Iteration: 1000 / 2000 [ 50%]  (Warmup)
>Chain 2: Iteration: 1001 / 2000 [ 50%]  (Sampling)
>Chain 2: Iteration: 1200 / 2000 [ 60%]  (Sampling)
>Chain 2: Iteration: 1400 / 2000 [ 70%]  (Sampling)
>Chain 2: Iteration: 1600 / 2000 [ 80%]  (Sampling)
>Chain 2: Iteration: 1800 / 2000 [ 90%]  (Sampling)
>Chain 2: Iteration: 2000 / 2000 [100%]  (Sampling)
>Chain 2: 
>Chain 2:  Elapsed Time: 2.751 seconds (Warm-up)
>Chain 2:                0.641 seconds (Sampling)
>Chain 2:                3.392 seconds (Total)
>Chain 2: 
>
>SAMPLING FOR MODEL 'continuous' NOW (CHAIN 3).
>Chain 3: 
>Chain 3: Gradient evaluation took 0 seconds
>Chain 3: 1000 transitions using 10 leapfrog steps per transition would take 0 seconds.
>Chain 3: Adjust your expectations accordingly!
>Chain 3: 
>Chain 3: 
>Chain 3: Iteration:    1 / 2000 [  0%]  (Warmup)
>Chain 3: Iteration:  200 / 2000 [ 10%]  (Warmup)
>Chain 3: Iteration:  400 / 2000 [ 20%]  (Warmup)
>Chain 3: Iteration:  600 / 2000 [ 30%]  (Warmup)
>Chain 3: Iteration:  800 / 2000 [ 40%]  (Warmup)
>Chain 3: Iteration: 1000 / 2000 [ 50%]  (Warmup)
>Chain 3: Iteration: 1001 / 2000 [ 50%]  (Sampling)
>Chain 3: Iteration: 1200 / 2000 [ 60%]  (Sampling)
>Chain 3: Iteration: 1400 / 2000 [ 70%]  (Sampling)
>Chain 3: Iteration: 1600 / 2000 [ 80%]  (Sampling)
>Chain 3: Iteration: 1800 / 2000 [ 90%]  (Sampling)
>Chain 3: Iteration: 2000 / 2000 [100%]  (Sampling)
>Chain 3: 
>Chain 3:  Elapsed Time: 2.393 seconds (Warm-up)
>Chain 3:                0.888 seconds (Sampling)
>Chain 3:                3.281 seconds (Total)
>Chain 3: 
>
>SAMPLING FOR MODEL 'continuous' NOW (CHAIN 4).
>Chain 4: 
>Chain 4: Gradient evaluation took 0 seconds
>Chain 4: 1000 transitions using 10 leapfrog steps per transition would take 0 seconds.
>Chain 4: Adjust your expectations accordingly!
>Chain 4: 
>Chain 4: 
>Chain 4: Iteration:    1 / 2000 [  0%]  (Warmup)
>Chain 4: Iteration:  200 / 2000 [ 10%]  (Warmup)
>Chain 4: Iteration:  400 / 2000 [ 20%]  (Warmup)
>Chain 4: Iteration:  600 / 2000 [ 30%]  (Warmup)
>Chain 4: Iteration:  800 / 2000 [ 40%]  (Warmup)
>Chain 4: Iteration: 1000 / 2000 [ 50%]  (Warmup)
>Chain 4: Iteration: 1001 / 2000 [ 50%]  (Sampling)
>Chain 4: Iteration: 1200 / 2000 [ 60%]  (Sampling)
>Chain 4: Iteration: 1400 / 2000 [ 70%]  (Sampling)
>Chain 4: Iteration: 1600 / 2000 [ 80%]  (Sampling)
>Chain 4: Iteration: 1800 / 2000 [ 90%]  (Sampling)
>Chain 4: Iteration: 2000 / 2000 [100%]  (Sampling)
>Chain 4: 
>Chain 4:  Elapsed Time: 3.804 seconds (Warm-up)
>Chain 4:                1.075 seconds (Sampling)
>Chain 4:                4.879 seconds (Total)
>Chain 4: 
>Warning: There were 1 divergent transitions after warmup. See
>http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
>to find out why this is a problem and how to eliminate them.
>Warning: Examine the pairs() plot to diagnose sampling problems

どのようなモデルを使ったかを見てみます。

summary(roc_model_via_df)
>Model Info:
> function:     stan_glmer
> family:       gaussian [identity]
> formula:      statistic ~ model + (1 | id)
> algorithm:    sampling
> sample:       4000 (posterior sample size)
> priors:       see help('prior_summary')
> observations: 20
> groups:       id (10)
>
>Estimates:
>                                    mean   sd   10%
>(Intercept)                       0.9    0.0  0.9  
>modellogistic                     0.0    0.0  0.0  
>b[(Intercept) id:Fold01]          0.0    0.0  0.0  
>b[(Intercept) id:Fold02]          0.0    0.0  0.0  
>b[(Intercept) id:Fold03]          0.0    0.0  0.0  
>b[(Intercept) id:Fold04]          0.0    0.0  0.0  
>b[(Intercept) id:Fold05]          0.0    0.0  0.0  
>b[(Intercept) id:Fold06]          0.0    0.0  0.0  
>b[(Intercept) id:Fold07]          0.0    0.0  0.0  
>b[(Intercept) id:Fold08]          0.0    0.0  0.0  
>b[(Intercept) id:Fold09]          0.0    0.0  0.0  
>b[(Intercept) id:Fold10]          0.0    0.0  0.0  
>sigma                             0.0    0.0  0.0  
>Sigma[id:(Intercept),(Intercept)] 0.0    0.0  0.0  
>                                    50%   90%
>(Intercept)                       0.9   0.9  
>modellogistic                     0.0   0.0  
>b[(Intercept) id:Fold01]          0.0   0.0  
>b[(Intercept) id:Fold02]          0.0   0.0  
>b[(Intercept) id:Fold03]          0.0   0.1  
>b[(Intercept) id:Fold04]          0.0   0.0  
>b[(Intercept) id:Fold05]          0.0   0.0  
>b[(Intercept) id:Fold06]          0.0   0.0  
>b[(Intercept) id:Fold07]          0.0   0.0  
>b[(Intercept) id:Fold08]          0.0   0.0  
>b[(Intercept) id:Fold09]          0.0   0.0  
>b[(Intercept) id:Fold10]          0.0   0.0  
>sigma                             0.0   0.0  
>Sigma[id:(Intercept),(Intercept)] 0.0   0.0  
>
>Fit Diagnostics:
>           mean   sd   10%   50%   90%
>mean_PPD 0.9    0.0  0.9   0.9   0.9  
>
>The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see >help('summary.stanreg')).
>
>MCMC diagnostics
>                                  mcse Rhat n_eff
>(Intercept)                       0.0  1.0   602 
>modellogistic                     0.0  1.0  2124 
>b[(Intercept) id:Fold01]          0.0  1.0   689 
>b[(Intercept) id:Fold02]          0.0  1.0   736 
>b[(Intercept) id:Fold03]          0.0  1.0   687 
>b[(Intercept) id:Fold04]          0.0  1.0   688 
>b[(Intercept) id:Fold05]          0.0  1.0   688 
>b[(Intercept) id:Fold06]          0.0  1.0   660 
>b[(Intercept) id:Fold07]          0.0  1.0   678 
>b[(Intercept) id:Fold08]          0.0  1.0   724 
>b[(Intercept) id:Fold09]          0.0  1.0   693 
>b[(Intercept) id:Fold10]          0.0  1.0   683 
>sigma                             0.0  1.0   769 
>Sigma[id:(Intercept),(Intercept)] 0.0  1.0  1450 
>mean_PPD                          0.0  1.0  4202 
>log-posterior                     0.2  1.0   606 
>
>For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is >the potential scale reduction factor on split chains (at convergence Rhat=1).

tidy()で図示化します。

roc_model_via_df %>% 
  tidy() %>% 
  ggplot(aes(x = posterior)) + 
  geom_histogram(bins = 40, col = "blue", fill = "blue", alpha = .4) + 
  facet_wrap(~ model, ncol = 1) + 
  xlab("Area Under the ROC Curve")

image.png

少しだけロジスティック回帰の分類性能が優れてそうです。
次に、実際に10分割交差検証法を10回繰り返した結果と比較してみます。
(合計100回の交差検証法の結果をプロットする)

library(patchwork)
for(i in 1:9){
logistic_reg_glm_res <- 
  logistic_reg_glm_spec %>% 
  fit_resamples(Class ~ ., resamples = folds, control = rs_ctrl)

lda_res <- 
  lda_spec %>% 
  fit_resamples(Class ~ ., resamples = folds, control = rs_ctrl)

logistic_roc <- 
  collect_metrics(logistic_reg_glm_res, summarize = FALSE) %>% 
  dplyr::filter(.metric == "roc_auc") %>% 
  dplyr::select(id, logistic = .estimate)

lda_roc <- 
  collect_metrics(lda_res, summarize = FALSE) %>% 
  dplyr::filter(.metric == "roc_auc") %>% 
  dplyr::select(id, lda = .estimate)

resamples_df2 <- full_join(logistic_roc, lda_roc, by = "id")
resamples_df <- bind_rows(resamples_df,resamples_df2)
}

p1 <- ggplot(resamples_df, aes(x=lda)) + 
  geom_histogram(bins = 10, col = "blue", fill = "blue", alpha = .4) 

p2 <- ggplot(resamples_df, aes(x=logistic)) + 
  geom_histogram(bins = 10, col = "blue", fill = "blue", alpha = .4)

p1 | p2

image.png
実際に交差検証法を繰り返した分布も、同じような結果になりました。

5.参考

tidyposterior

6.Enjoy

5
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
3