なるべくデフォルトに丸投げして、統一的な手法で解析してみる。
ames
データセットの Central_Air
の予測を考える。
事前準備
Calibration Plot に必要なパッケージをインストールする。
remotes::install_github("tidymodels/probably")
パッケージを読み込む(必要ならインストールしておく)。
library(tidyverse)
library(tidymodels)
解析用の関数を準備する。
cv_fit_pred <- function(
recipe, spec, df_train, df_test, df_cv = NULL, grid = 10
) {
if(is.null(df_cv)) {
params_grid = NULL
# ワークフローを設定する
wf <-
workflow() |>
workflows::add_recipe(recipe = recipe) |>
workflows::add_model(spec = spec)
} else {
# 交差検証用のワークフローを設定する
cv_wf <-
workflow() |>
workflows::add_recipe(recipe = recipe) |>
workflows::add_model(spec = spec)
# 最適なハイパーパラメータを探索する
params_grid <-
cv_wf |>
tune::tune_grid(
resamples = df_cv,
grid = grid,
control = tune::control_grid(save_pred = TRUE)
)
# 最適なハイパーパラメータによるワークフローを設定する
wf <-
cv_wf |>
tune::finalize_workflow(
parameters = params_grid |> tune::select_best()
)
}
# ワークフローを学習する
wf_fit <-
wf |>
parsnip::fit(data = df_train)
# 学習したワークフローで予測する
pred <-
wf_fit |>
tune::augment(new_data = df_test)
return(
list(
params_grid = params_grid,
wf_fit = wf_fit,
pred = pred
)
)
}
parsnip::parsnip_addin()
を使って、ハイパーパラメータ・チューニングを前提としてモデルを指定する。
# 決定木
rpart_spec <-
parsnip::decision_tree() |>
parsnip::set_args(
tree_depth = tune(), min_n = tune(), cost_complexity = tune()
) |>
parsnip::set_engine('rpart') |>
parsnip::set_mode('classification')
# ランダム・フォレスト
rf_spec <-
parsnip::rand_forest() |>
parsnip::set_args(mtry = tune(), min_n = tune(), trees = 5000) |>
parsnip::set_engine(
'ranger',
num.threads = parallel::detectCores()
) |>
parsnip::set_mode('classification')
# ラッソ回帰
lasso_spec <-
logistic_reg() |>
parsnip::set_args(penalty = tune(), mixuture = 1) |>
parsnip::set_engine('glmnet') |>
parsnip::set_mode('classification')
# ロジスティック回帰
lr_spec <-
parsnip::logistic_reg() |>
parsnip::set_engine('glm') |>
parsnip::set_mode('classification')
データを準備する。
data("ames")
ames_split <-
ames |>
rsample::initial_split(strata = Central_Air)
ames_train <-
ames_split |>
rsample::training()
ames_test <-
ames_split |>
rsample::testing()
ames_cv <-
ames_train |>
rsample::vfold_cv(v = 10, strata = Central_Air)
ハイパーパラメータ・チューニング・学習・予測
決定木で学習する。
ames_rec <-
ames_train |>
recipes::recipe(Central_Air ~ .) |>
recipes::step_zv(all_predictors()) |>
prep()
rpart_result <-
ames_rec |>
cv_fit_pred(rpart_spec, ames_train, ames_test, ames_cv)
# ハイパーパラメータ・チューニングをチェック
rpart_result$params_grid |>
tune::show_best()
# 決定木の図示
rpart_result$wf |>
workflows::extract_fit_engine() |>
rpart.plot::rpart.plot()
# 推測結果
rpart_result$pred
ランダム・フォレストで学習する。
ames_rec <-
ames_train |>
recipes::recipe(Central_Air ~ .) |>
recipes::step_zv(all_predictors()) |>
prep()
rf_result <-
ames_rec |>
cv_fit_pred(rf_spec, ames_train, ames_test, ames_cv)
# ハイパーパラメータ・チューニングをチェック
rf_result$params_grid |>
tune::show_best()
# 推測結果
rf_result$pred
ラッソ回帰で学習する。
# `step_dummy` でダミー変数化する
ames_rec <-
ames_train |>
recipes::recipe(Central_Air ~ .) |>
recipes::step_ordinalscore(all_ordered_predictors()) |>
recipes::step_dummy(all_unordered_predictors()) |>
recipes::step_zv(all_predictors()) |>
recipes::step_normalize(all_numeric_predictors()) |>
prep()
lasso_result <-
ames_rec |>
cv_fit_pred(lasso_spec, ames_train, ames_test, ames_cv)
# ハイパーパラメータ・チューニングをチェック
lasso_result$params_grid |>
tune::show_best()
# 係数をチェック
lasso_result$wf_fit |>
workflows::extract_fit_parsnip() |>
broom::tidy()
# 推測結果
lasso_result$pred
ロジスティック回帰で学習する(ハイパーパラメータ・チューニングなし)。
# `step_other` で頻度が少ない項目(5%未満)に対処する
ames_rec <-
ames_train |>
recipes::recipe(Central_Air ~ .) |>
recipes::step_other(all_nominal_predictors()) |>
recipes::step_zv(all_predictors()) |>
prep()
lr_result <-
ames_rec |>
cv_fit_pred(lr_spec, ames_train, ames_test)
# 係数をチェック
lr_result$wf_fit |>
workflows::extract_fit_engine() |>
broom::tidy()
# 推測結果
lr_result$pred
評価
解析用の関数を準備する。
plot_evaluation <- function(df, truth, estimate) {
# ROC Curve を作成する
p_roc <-
df |>
dplyr::rename(truth = {{ truth }}) |>
yardstick::roc_curve(truth = truth, {{ estimate }}) |>
ggplot2::ggplot(aes(x = 1 - specificity, y = sensitivity)) +
geom_path() +
geom_abline(linetype = 2) +
coord_obs_pred() +
cowplot::theme_cowplot() +
cowplot::background_grid()
# PR Curve を作成する
p_pr <-
df |>
dplyr::rename(truth = {{ truth }}) |>
yardstick::pr_curve(truth = truth, {{ estimate }}) |>
ggplot2::ggplot(aes(x = recall, y = precision)) +
geom_path() +
geom_abline(slope = -1, intercept = 1, linetype = 2) +
coord_obs_pred() +
cowplot::theme_cowplot() +
cowplot::background_grid()
# Calibration Plot を作成する
p_calibration <-
df |>
dplyr::rename(truth = {{ truth }}) |>
probably::cal_plot_breaks(truth = truth, {{ estimate }}) +
coord_obs_pred() +
cowplot::theme_cowplot() +
cowplot::background_grid()
return(
list(
roc_curve = p_roc,
pr_curve = p_pr,
calibration_plot = p_calibration
)
)
}
ROC Curve, PR Curve, Calibration Plot を図示する。
p <-
rf_result$pred |>
plot_evaluation(truth = "Central_Air", estimate = ".pred_N")
p
評価指標を調べる。
# accuracy 等を確認する
rf_result$pred |>
yardstick::metrics(truth = Central_Air, estimate = .pred_class)
# 混同行列を確認する
rf_result$pred |>
yardstick::conf_mat(truth = Central_Air, estimate = .pred_class)