0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

tidymodels で分類モデルをハイパーパラメータ・チューニングして学習して予測して評価する

Posted at

なるべくデフォルトに丸投げして、統一的な手法で解析してみる。
ames データセットの Central_Air の予測を考える。

事前準備

Calibration Plot に必要なパッケージをインストールする。

remotes::install_github("tidymodels/probably")

パッケージを読み込む(必要ならインストールしておく)。

library(tidyverse)
library(tidymodels)

解析用の関数を準備する。

cv_fit_pred <- function(
        recipe, spec, df_train, df_test, df_cv = NULL, grid = 10
) {
    
    if(is.null(df_cv)) {
        
        params_grid = NULL
        
        # ワークフローを設定する
        wf <- 
            workflow() |> 
            workflows::add_recipe(recipe = recipe) |> 
            workflows::add_model(spec = spec)
        
    } else {
        
        # 交差検証用のワークフローを設定する
        cv_wf <- 
            workflow() |> 
            workflows::add_recipe(recipe = recipe) |> 
            workflows::add_model(spec = spec)
        
        # 最適なハイパーパラメータを探索する
        params_grid <- 
            cv_wf |> 
            tune::tune_grid(
                resamples = df_cv,
                grid = grid,
                control = tune::control_grid(save_pred = TRUE)
            )
        
        # 最適なハイパーパラメータによるワークフローを設定する
        wf <- 
            cv_wf |> 
            tune::finalize_workflow(
                parameters = params_grid |> tune::select_best()
            )
        
    }
    
    # ワークフローを学習する
    wf_fit <- 
        wf |> 
        parsnip::fit(data = df_train)
    
    # 学習したワークフローで予測する
    pred <- 
        wf_fit |> 
        tune::augment(new_data = df_test)
    
    return(
        list(
            params_grid = params_grid, 
            wf_fit = wf_fit, 
            pred = pred
        )
    )
    
}

parsnip::parsnip_addin() を使って、ハイパーパラメータ・チューニングを前提としてモデルを指定する。

# 決定木
rpart_spec <- 
    parsnip::decision_tree() |> 
    parsnip::set_args(
        tree_depth = tune(), min_n = tune(), cost_complexity = tune()
    ) |> 
    parsnip::set_engine('rpart') |> 
    parsnip::set_mode('classification')

# ランダム・フォレスト
rf_spec <-
    parsnip::rand_forest() |> 
    parsnip::set_args(mtry = tune(), min_n = tune(), trees = 5000) |> 
    parsnip::set_engine(
        'ranger',
        num.threads = parallel::detectCores()
    ) |> 
    parsnip::set_mode('classification')

# ラッソ回帰
lasso_spec <-
    logistic_reg() |> 
    parsnip::set_args(penalty = tune(), mixuture = 1) |> 
    parsnip::set_engine('glmnet') |> 
    parsnip::set_mode('classification')

# ロジスティック回帰
lr_spec <-
    parsnip::logistic_reg() |> 
    parsnip::set_engine('glm') |> 
    parsnip::set_mode('classification')

データを準備する。

data("ames")

ames_split <- 
    ames |> 
    rsample::initial_split(strata = Central_Air)

ames_train <- 
    ames_split |> 
    rsample::training()

ames_test <- 
    ames_split |> 
    rsample::testing()

ames_cv <- 
    ames_train |> 
    rsample::vfold_cv(v = 10, strata = Central_Air)

ハイパーパラメータ・チューニング・学習・予測

決定木で学習する。

ames_rec <- 
    ames_train |> 
    recipes::recipe(Central_Air ~ .) |> 
    recipes::step_zv(all_predictors()) |> 
    prep()

rpart_result <- 
    ames_rec |> 
    cv_fit_pred(rpart_spec, ames_train, ames_test, ames_cv)

# ハイパーパラメータ・チューニングをチェック
rpart_result$params_grid |> 
    tune::show_best()

# 決定木の図示
rpart_result$wf |> 
    workflows::extract_fit_engine() |> 
    rpart.plot::rpart.plot()

# 推測結果
rpart_result$pred

ランダム・フォレストで学習する。

ames_rec <- 
    ames_train |> 
    recipes::recipe(Central_Air ~ .) |> 
    recipes::step_zv(all_predictors()) |> 
    prep()

rf_result <- 
    ames_rec |> 
    cv_fit_pred(rf_spec, ames_train, ames_test, ames_cv)

# ハイパーパラメータ・チューニングをチェック
rf_result$params_grid |> 
    tune::show_best()

# 推測結果
rf_result$pred

ラッソ回帰で学習する。

# `step_dummy` でダミー変数化する
ames_rec <- 
    ames_train |> 
    recipes::recipe(Central_Air ~ .) |> 
    recipes::step_ordinalscore(all_ordered_predictors()) |> 
    recipes::step_dummy(all_unordered_predictors()) |> 
    recipes::step_zv(all_predictors()) |> 
    recipes::step_normalize(all_numeric_predictors()) |> 
    prep()

lasso_result <- 
    ames_rec |> 
    cv_fit_pred(lasso_spec, ames_train, ames_test, ames_cv)

# ハイパーパラメータ・チューニングをチェック
lasso_result$params_grid |> 
    tune::show_best()

# 係数をチェック
lasso_result$wf_fit |> 
    workflows::extract_fit_parsnip() |> 
    broom::tidy()

# 推測結果
lasso_result$pred

ロジスティック回帰で学習する(ハイパーパラメータ・チューニングなし)。

# `step_other` で頻度が少ない項目(5%未満)に対処する
ames_rec <- 
    ames_train |> 
    recipes::recipe(Central_Air ~ .) |> 
    recipes::step_other(all_nominal_predictors()) |> 
    recipes::step_zv(all_predictors()) |> 
    prep()

lr_result <- 
    ames_rec |> 
    cv_fit_pred(lr_spec, ames_train, ames_test)

# 係数をチェック
lr_result$wf_fit |> 
    workflows::extract_fit_engine() |> 
    broom::tidy()

# 推測結果
lr_result$pred

評価

解析用の関数を準備する。

plot_evaluation <- function(df, truth, estimate) {
    
    # ROC Curve を作成する
    p_roc <-
        df |> 
        dplyr::rename(truth = {{ truth }}) |>
        yardstick::roc_curve(truth = truth, {{ estimate }}) |>
        ggplot2::ggplot(aes(x = 1 - specificity, y = sensitivity)) +
        geom_path() +
        geom_abline(linetype = 2) +
        coord_obs_pred() +
        cowplot::theme_cowplot() +
        cowplot::background_grid()
    
    # PR Curve を作成する
    p_pr <-
        df |>
        dplyr::rename(truth = {{ truth }}) |>
        yardstick::pr_curve(truth = truth, {{ estimate }}) |>
        ggplot2::ggplot(aes(x = recall, y = precision)) +
        geom_path() +
        geom_abline(slope = -1, intercept = 1, linetype = 2) +
        coord_obs_pred() +
        cowplot::theme_cowplot() +
        cowplot::background_grid()
    
    # Calibration Plot を作成する
    p_calibration <-
        df |>
        dplyr::rename(truth = {{ truth }}) |>
        probably::cal_plot_breaks(truth = truth, {{ estimate }}) +
        coord_obs_pred() +
        cowplot::theme_cowplot() +
        cowplot::background_grid()
    
    return(
        list(
            roc_curve = p_roc, 
            pr_curve = p_pr, 
            calibration_plot = p_calibration
        )
    )
    
}

ROC Curve, PR Curve, Calibration Plot を図示する。

p <- 
    rf_result$pred |> 
    plot_evaluation(truth = "Central_Air", estimate = ".pred_N")

p

評価指標を調べる。

# accuracy 等を確認する
rf_result$pred |> 
    yardstick::metrics(truth = Central_Air, estimate = .pred_class)

# 混同行列を確認する
rf_result$pred |> 
    yardstick::conf_mat(truth = Central_Air, estimate = .pred_class)
0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?