2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

mlr3を使った機械学習

Last updated at Posted at 2021-07-31

1.はじめに

Rでは、機械学習のフレームワークにcaretが有名ですが、登場してから少し古くなっています。mlrがcaretと並ぶ機械学習のフレームワークとして有名ですが、2019年からメンテナンスモードになり、mlr3に移行しています。また、最新ではtidymodels も登場しています。
個人的にはmlrは、機械学習のフレームワークとしては気に入っています。
しかし、後継パッケージのmlr3については、現在(2024/11月)、日本語による情報は少ないので、参考にある情報を元に紹介したいと思います。
更新しました:2024年11月15日

参考
Machine Learning con R y mlr3

mlr3は、単一のフレームワークの下でさまざまなパッケージの何百もの機能を統合するインターフェイスであり、予測モデルの前処理、トレーニング、最適化、および検証のすべての段階が可能になります。

このドキュメントで使用されているパッケージは次のとおりです。

library(mlr3verse)
library(tidyverse)
library(skimr)
library(DataExplorer)
library(ggpubr)
library(univariateML)
library(GGally)
library(mosaicData)

2.データ

SaratogaHousesパッケージデータセットにmosaicDataは、2006年に米国ニューヨーク州サラトガ郡にある1,728戸の住宅の価格に関する情報が含まれています。
価格に加えて、15の説明変数が含まれています。

  • price:住宅価格。
  • lotSize:家の大きさ(平方メートル)。
  • age:家の築年数。
  • landValue:土地の価値。
  • livingArea:居住可能な平方メートル。
  • pctCollege:大学の学位を持つ近所の割合。
  • bedrooms:寝室の数。
  • firplaces:煙突の数。
  • bathrooms:バスルームの数(値0.5は、シャワーのないバスルームを示します)。
  • rooms: 室数。
  • heating:暖房の種類。
  • fuel:暖房供給のタイプ(ガス、電気またはディーゼル)。
  • sewer:ドレインのタイプ。
  • waterfront:家から湖の景色が見えるかどうか。
  • newConstruction:家が新しく建てられた場合。
  • centralAir:家にエアコンがある場合。

目的は住宅価格を予測します。

data("SaratogaHouses", package = "mosaicData")
dat = SaratogaHouses

3.探索的分析

機械学習にかける前に、データの探索的分析を実行することが非常に重要です。
このプロセスにより、各説明変数に含まれる情報をよりよく理解することができます。
さらに、この分析により、モデルの予測にどの説明変数が適しているかについての手がかりが得られます。
パッケージskimr、DataExplorerおよびGGallyを利用すると、このタスクがはるかに簡単になります。

3.1要約表の作成

skim(dat)

#> ─ Data Summary ────────────
#>                           Values
#> Name                       dat   
#> Number of rows             1728  
#> Number of columns          16    
#> _______________________          
#> Column type frequency:           
#>   factor                   6     
#>   numeric                  10    
#> ________________________         
#> Group variables            None  
#> 
#> ─ Variable type: factor ──────────────────────────────────
#>   skim_variable   n_missing complete_rate ordered n_unique top_counts               #>     
#> 1 heating                 0             1 FALSE          3 hot: 1121, ele: 305, hot: 302
#> 2 fuel                    0             1 FALSE          3 gas: 1197, ele: 315, oil: 216
#> 3 sewer                   0             1 FALSE          3 pub: 1213, sep: 503, non: 12 
#> 4 waterfront              0             1 FALSE          2 No: 1713, Yes: 15            
#> 5 newConstruction         0             1 FALSE          2 No: 1647, Yes: 81            
#> 6 centralAir              0             1 FALSE          2 No: 1093, Yes: 635      
#> 
#> 
#> ─ Variable type: numeric ─────────────────────────────────
#>    skim_variable n_missing complete_rate       mean        sd    p0       p25       p50
#>  1 price                 0             1 211967.    98441.     5000 145000    189900   
#>  2 lotSize               0             1      0.500     0.699     0      0.17      0.37
#>  3 age                   0             1     27.9      29.2       0     13        19   
#>  4 landValue             0             1  34557.    35021.      200  15100     25000   
#>  5 livingArea            0             1   1755.      620.      616   1300      1634.  
#>  6 pctCollege            0             1     55.6      10.3      20     52        57   
#>  7 bedrooms              0             1      3.15      0.817     1      3         3   
#>  8 fireplaces            0             1      0.602     0.556     0      0         1   
#>  9 bathrooms             0             1      1.90      0.658     0      1.5       2   
#> 10 rooms                 0             1      7.04      2.32      2      5         7   
#>          p75     p100 hist 
#>  1 259000    775000   ▅▇▂▁▁
#>  2      0.54     12.2 ▇▁▁▁▁
#>  3     34       225   ▇▁▁▁▁
#>  4  40200    412600   ▇▁▁▁▁
#>  5   2138.     5228   ▇▇▂▁▁
#>  6     64        82   ▁▃▆▇▁
#>  7      4         7   ▃▇▅▁▁
#>  8      1         4   ▆▇▁▁▁
#>  9      2.5       4.5 ▁▇▇▁▁
#> 10      8.25     12   ▃▇▇▅▂
head(dat, 5)

01.png

3.2目的変数の分布

目的変数の分布を調べることが非常に重要です。
結局のところ、それが予測したいものだからです。
いくつかの家の価格が平均よりはるかに高いために、目的変数は正のテールを持つ非対称分布を持っています。
このタイプの分布は、通常、対数または平方根が適用されます。

p1 = ggplot(data = dat, aes(x = price)) +
      geom_density(fill = "steelblue", alpha = 0.8) +
      geom_rug(alpha = 0.1) +
      scale_x_continuous(labels = scales::comma) +
      labs(title = "オリジナルの分布") +
      theme_bw() 

p2 = ggplot(data = dat, aes(x = sqrt(price))) +
      geom_density(fill = "steelblue", alpha = 0.8) +
      geom_rug(alpha = 0.1) +
      scale_x_continuous(labels = scales::comma) +
      labs(title = "平方根変換") +
      theme_bw() 

p3 = ggplot(data = dat, aes(x = log(price))) +
      geom_density(fill = "steelblue", alpha = 0.8) +
      geom_rug(alpha = 0.1) +
      scale_x_continuous(labels = scales::comma) +
      labs(title = "対数変換") +
      theme_bw() 

ggarrange(p1, p2, p3, ncol = 1, align = "v")

02.png
目的変数の要約です。

summary(dat$price)

#>   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
#>   5000  145000  189900  211967  259000  775000 

統計学習モデルでは、目的変数を特定の方法で必要に応じて変数変換する必要があります。
例えば、線形回帰(LM)モデルの場合は、分布は正規分布が求められます。
一般化線形モデル(GLM)の場合は、分布は指数型分布族が求められます。

Rには、データの最適な分布を特定できるunivariateMLというパッケージがあります。

comparacion_aic = AIC(
                    mlbetapr(dat$price),
                    mlexp(dat$price),
                    mlinvgamma(dat$price),
                    mlgamma(dat$price),
                    mllnorm(dat$price),
                    mlrayleigh(dat$price),
                    mlinvgauss(dat$price),
                    mlweibull(dat$price),
                    mlinvweibull(dat$price),
                    mllgamma(dat$price)
                   )
comparacion_aic %>% rownames_to_column(var = "distribucion") %>% arrange(AIC)

03.png
目的変数が従う確率分布はAICからガンマ分布と正規分布が最適です。

3.3説明変数の分布

plot_density(
  data    = dat %>% select(-price),
  ncol    = 3,
  title   = "説明変数(numeric型)の分布",
  ggtheme = theme_bw(),
  theme_config = list(
                  plot.title = element_text(size = 16, face = "bold"),
                  strip.text = element_text(colour = "black", size = 12, face = 2)
                 )
  )

image.png

説明変数"fireplaces"は、numeric型をとりますが、この場合はfactor型に変えます。

table(dat$fireplaces)

#>  0   1   2   3   4 
#> 740 942  42   2   2 
dat = dat %>%
         mutate(fireplaces = as.factor(fireplaces))

住宅価格を予測することが目的であるため、各説明変数の分析も目的変数に関連して行います。
このようにデータ分析することで、どの説明変数がどのように価格に関連しているかについてのアイデアを得ることができます。

説明変数("age","lotSize","landValue")は、右側に大きくテールが伸びた分布であるため、対数変換します。

dat = dat %>%
         mutate(
           log_age    = log(age + 1),
           log_lotSize = log(lotSize),
           log_landValue = log(landValue)
         )
custom_corr_plot = function(variable1, variable2, df, alpha=0.3){
  p = df %>%
       mutate(
          title = paste(toupper(variable2), "vs", toupper(variable1))
       ) %>%
       ggplot(aes(x = !!sym(variable1), y = !!sym(variable2))) + 
       geom_point(alpha = alpha) +
       # 加法モデル(GAM)
       geom_smooth(se = FALSE, method = "gam", formula =  y ~ splines::bs(x, 3)) +
       # 線形モデル(LM)
       geom_smooth(se = FALSE, method = "lm", color = "firebrick") +
       facet_grid(. ~ title) +
       theme_bw() +
       theme(strip.text = element_text(colour = "black", size = 8, face = 2),
             axis.title = element_blank())
  return(p)
}
variables_continuas = c("livingArea", "pctCollege", "bedrooms",
                         "bathrooms", "rooms", "log_age", "log_lotSize",
                         "log_landValue")

plots <- map(
            .x = variables_continuas,
            .f = custom_corr_plot,
            variable2 = "price",
            df = dat
         )

ggarrange(plotlist = plots, ncol = 3, nrow = 3) %>%
  annotate_figure(
    top = text_grob("住宅価格と説明変数との相関関係", face = "bold", size = 16,
                    x = 0.50)
  )

image.png

3.4説明変数間の相関関係

線形回帰分析は、説明変数間の相関関係が高いと予測に影響を与えます(多重共線性)。

plot_correlation(
  data = dat %>% select(-c("log_age","log_lotSize","log_landValue")),
  type = "continuous",
  title = "目的変数及び説明変数(連続値)の相関行列",
  theme_config = list(legend.position = "none",
                      plot.title = element_text(size = 12, face = "bold"),
                      axis.title = element_blank(),
                      axis.text.x = element_text(angle = -45, hjust = +0.1)
                     )
)

image.png

GGally::ggscatmat(
  data = dat %>% select_if(is.numeric) %>% select(-c("log_age","log_lotSize","log_landValue")),
  alpha = 0.1) +
theme_bw() +
labs(title = "相関行列") +
theme(
  plot.title = element_text(size = 16, face = "bold"),
  axis.text = element_blank(),
  strip.text = element_text(colour = "black", size = 5, face = 1)
)

image.png

3.5 質的変数(factor型)

plot_bar(
  dat,
  ncol    = 3,
  title   = "質的変数のカテゴリー別の観測値",
  ggtheme = theme_bw(),
  theme_config = list(
                   plot.title = element_text(size = 16, face = "bold"),
                   strip.text = element_text(colour = "black", size = 10, face = 2),
                   legend.position = "none"
                  )
)

image.png
質的変数のカテゴリー別の観測値の中に、属する観測値が極端に少ないカテゴリーが含まれる場合は、交差検証法等を行った場合におけるデータ分割時に、カテゴリーに属する値が0になってしまうため、不都合が生じる可能性があります。

この場合、次のように対処します。

・マルチカテゴリー変数の場合は、少数派カテゴリーの観測値を削除します。
・カテゴリーが2つしかない場合は、その変数を削除します。
・少数派のカテゴリーを1つのカテゴリーにグループ化します。

交差検証法等におけるデータを分割時には、すべてのカテゴリーの観測値が含まれるように注意してください。

5.5セクションの前処理で対応します。

custom_box_plot <- function(variable1, variable2, df, alpha=0.3){
  p <- df %>%
       mutate(
         # Trick to display the title in facet style
        title = paste(toupper(variable2), "vs", toupper(variable1))
       ) %>%
       ggplot(aes(x = !!sym(variable1), y = !!sym(variable2))) + 
       geom_violin(alpha = alpha) +
       geom_boxplot(width = 0.1, outlier.shape = NA) +
       facet_grid(. ~ title) +
       theme_bw() +
       theme(strip.text = element_text(colour = "black", size = 8, face = 2),
             axis.text.x = element_text(size = 7),
             axis.title = element_blank())
  return(p)
}
variables_cualitativas <- c("fireplaces", "heating", "fuel", "sewer",
                            "waterfront", "newConstruction", "centralAir")

plots <- map(
            .x = variables_cualitativas,
            .f = custom_box_plot,
            variable2 = "price",
            df = dat
         )

ggarrange(plotlist = plots, ncol = 3, nrow = 3) %>%
  annotate_figure(
    top = text_grob("住宅価格と質的変数との相関関係", face = "bold", size = 16,
                    x = 0.5)
  )

image.png

4.mlr3でモデルを作成する

訓練データからデータのパターンを学習し、適切なモデルを取得するための手順は次のとおりです。

4.1 調整/トレーニング

モデルが学習できるように、訓練データに機械学習アルゴリズムを適用します。

4.2 評価/検証

予測モデルの目標は、既知の観測値を予測できるようにすることではなく、モデルが認識していない新しい観測値を予測できるようにすることです。
交差検証法を用いて、モデルの精度を検証します。その中で、単純な交差検証、ブートストラップ法、があります。

4.3 ハイパーパラメータの最適化

多くの機械学習アルゴリズムでは、データで学習できない1つ以上のパラメータが含まれています。
これらは、ハイパーパラメータと呼ばれます。
たとえば、線形SVMにはコストハイパーパラメータCがあり、回帰決定木モデルーにはtree_depthの木の深さとノードあたりの最小観測数min_nがあります。
最良のモデルを生み出すハイパーパラメータの正確な値を事前に知る方法はないため、さまざまな値を比較するために交差検証法による検証が必要です。

4.3 予測

モデルができあがると、新しい観測値(テストデータ)を予測します。

mlr3が最も際立っているのはこのプロセス全体であり、同じ構文を使用して、アルゴリズムの名前を変えるだけで、さまざまなモデルを調整、最適化、評価、および予測をすることができます。
mlr3は、わずか数行のコードでこのすべてを可能になります。

mlr3開発者は、機械学習モデルを作成する段階を6つの「ブロック」に分割しました

image.png
図1 機械学習モデルを構築するブロック

これらの「ブロック」のそれぞれの中で、1つ以上のアクションが実行されます。

以下は、mlr3を使用して予測モデルを作成するために従う手順をまとめたリストです。

4.3.1 Taskの作成

Taskは、データ、変数のタイプ、観測の数、モデルのタイプ(回帰、分類など)に関する情報をカプセル化し、各変数の役割(予測子、応答、層別化、重みなど)を識別します。
使用可能なタスクのタイプは次のとおりです。

  • mlr3::TaskReg:回帰モデル。
  • mlr3::TaskClassif:2者またはマルチクラス分類モデル。
  • mlr3::TaskSurv:生存解析モデル

4.3.2 Learner(学習器)の作成

learnerは学習器を作成します。
使用する学習アルゴリズム、そのハイパーパラメータ、および予測のタイプを定義します。
また、データの前処理が指定されるのはこのオブジェクトです。

4.3.3 Resampleの作成

resampleはモデルの交差検証及び目的関数や損失関数等その予測能力を定量化するための方法を定義します。

4.3.4 Modelの学習

メソッド関数train()により、resample(),task,learner、およびオプションで適用するresampleによって、モデルが訓練され、その誤差を推定することができます。

4.3.5 Predection

メソッド関数predict()により、学習済みのモデルを使用して、新しいデータを予測します。


これらの「ブロック」のそれぞれの中で、1つ以上のアクションが実行されます。

5 MLR3による機械学習

5.1 Taskの作成

Taskは、データ、変数のタイプ、観測数、モデルのタイプ(回帰、分類など)に関する情報をカプセル化し、各変数の役割を識別します。

dat = dat %>% select(-c("log_age","log_landValue","log_lotSize"))
task_dat = TaskRegr$new(
                id      = "task_dat",
                backend = dat,
                target  = "price"
              )
task_dat

#> <TaskRegr:task_dat> (1728 x 16)
#> * Target: price
#> * Properties: -
#> * Features (15):
#>   - fct (7): centralAir, fireplaces, fuel, heating, newConstruction, sewer,
#>     waterfront
#>   - int (6): age, bedrooms, landValue, livingArea, pctCollege, rooms
#>   - dbl (2): bathrooms, lotSize

タスクに含まれる説明変数を表示します。

task_dat$col_info

image.png

task_dat$col_roles

#> $feature
#> [1] "age"             "bathrooms"       "bedrooms"        "centralAir"     
#> [5] "fireplaces"      "fuel"            "heating"         "landValue"      
#> [9] "livingArea"      "lotSize"         "newConstruction" "pctCollege"     
#> [13] "rooms"           "sewer"           "waterfront"     
#>
#> $target
#> [1] "price"
#> 
#> $name
#> character(0)
#> 
#> $order
#> character(0)
#> 
#> $stratum
#> character(0)
#> 
#> $group
#> character(0)
#> 
#> $weight
#> character(0)
#> 
#> $always_included
#> character(0)

taskのデータに欠測値があるかどうかを調べます

task_dat$missings()

#>          price             age       bathrooms        bedrooms      centralAir 
#>               0               0               0               0               0 
#>      fireplaces            fuel         heating       landValue      livingArea 
#>               0               0               0               0               0 
#>         lotSize newConstruction      pctCollege           rooms           sewer 
#>               0               0               0               0               0 
#>      waterfront 
#>               0 

質的変数に含まれるカテゴリーを表示します

task_dat$levels()

#> $centralAir
#> [1] "Yes" "No" 
#> 
#> $fireplaces
#> [1] "0" "1" "2" "3" "4"
#> 
#> $fuel
#> [1] "gas"      "electric" "oil"     
#> 
#> $heating
#> [1] "hot air"         "hot water/steam" "electric"       
#> 
#> $newConstruction
#> [1] "Yes" "No" 
#> 
#> $sewer
#> [1] "septic"            "public/commercial" "none"             
#> 
#> $waterfront
#> [1] "Yes" "No" 

taskに含まれるオリジナルデータを表示します。

task_dat$data()

image.png

5.2 学習及び評価

モデルの予測能力の評価は、その予測が目的変数の真の値にどれだけ近いかを評価することで構成されます。
それを正しく定量化できるようにするために、モデルを学習するデータとは別のデータ、つまり学習に関与していない新しいデータが必要です。
このために、利用可能なデータは訓練データとテストデータに分けられます。
分割する割合は、通常、80%程度で良好な結果が得られます。
ランダムまたはランダムに階層化された方法で分割します。

set.seed(1234)
rsmp_holdout = rsmp("holdout", ratio = 0.8)
rsmp_holdout

#> <ResamplingHoldout> with 1 iterations
#> * Instantiated: FALSE
#> * Parameters: ratio=0.8
rsmp_holdout$instantiate(task = task_dat)

#iはサンプリング開始する行番号を指定
id_train = rsmp_holdout$train_set(i = 1)
id_test = rsmp_holdout$test_set(i = 1)

#データをコピーするときはcloneメソッドを用いる
task_train = task_dat$clone()$filter(id_train)
task_test  = task_dat$clone()$filter(id_test)

目的変数の分布が訓練データとテストデータで大きく異なっていないかを確認することが重要です。

summary(task_train$data()[["price"]])

#> Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
#>   10300  145000  189000  211594  256838  775000 
summary(task_test$data()[["price"]])

 #> Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
 #>   5000  142250  194000  213455  260000  760000

5.3 mlr3pipelinesを用いた前処理

前処理には、機械学習アルゴリズムで解釈できるように、データに対して実行されるすべての変換が含まれます。
データの前処理は、訓練データ及びテストデータの両者に適用する必要があります。
最も一般的に必要な前処理手順のいくつかを次に示します。

5.3.1 欠測値の扱い

アルゴリズムの大部分は欠測値を受け入れないため、データセットに欠測値が含まれている場合は、適切な対応が必要です。

  • 欠測値の除去

  • 欠測値を含む説明変数の削除

  • 利用可能な残りのデータを活用して欠測値を推定する

最初の2つのオプションは単純ですが、情報を失うことを伴います。
欠測値の削除は多くが利用可能で、欠測値の割合が非常に低い場合にのみ適用できます。

欠測値を含む説明変数を削除する場合、影響はこれらの変数がモデルに寄与する割合によって異なります。
欠測値の推定を使用する場合、モデルに大きな影響を与える説明変数に新たな値を入力するリスクを考慮することが非常に重要です。

5.3.2 説明変数の中で、ほぼ単一の値をとり、分散が0に近い場合

ほぼ単一の値をとり、分散が0に近い場合は、情報を提供しないため、モデルに含めるべきではありません。
また、ごくまれにしか表示されない値をいくつか取る説明変数を含めることもあまりよくありません。
クロスバリデーションやブートストラップによってデータが分割されたときに、単一の情報しかもたない説明変数になる可能性があります。

情報量の少ない説明変数の削除は、モデル構築のための適切なステップと見なすことができますが、データを標準化する前に実行する必要があります。

5.3.3 説明変数の標準化とスケーリング

説明変数が数値である場合、それらの観測値におけるスケール、や分散の大きさの違いは、モデルに大きな影響を与える可能性があります。多くの機械学習アルゴリズムは、これに敏感であるため、それを回避するために次のことをします。

(1) センタリング:各値から、それが属する説明変数の平均を差し引くことで構成されます。データがデータフレームに格納されている場合、センタリングは、各値からデータが配置されている列の平均を差し引くことによって実現されます。この変換の結果、すべての予測子の平均はゼロになります。つまり、値は原点を中心に配置されます。

(2) 正規化(標準化):すべての説明変数がほぼ同じスケールになるようにデータを変換することで構成されます。これを行うには2つの方法があります。

  • Zスコアの正規化:センタリングの後、その標準偏差で割ります。このようにして、データは正規分布になります。

  • 最大-最小標準化:[0、1]の範囲内になるようにデータを変換します。

(3) 質的変数の二値化
二値化は、新しいダミー変数を作成することで構成されます。
このプロセスは、ワンホットエンコーディングとも呼ばれます。例えば、red、green、blueを含むcolorという名前の変数は、3つの新しい変数(color_red、color_green、color_blue)に変換されます。

mlr3で前処理する方法は、パッケージmlr3pipelines、mlrfiltersを用います。

# 前処理をするオペレータの種類
mlr_pipeops$keys()

#>  [1] "adas"                  "blsmote"               "boxcox"               
#>  [4] "branch"                "chunk"                 "classbalancing"       
#>  [7] "classifavg"            "classweights"          "colapply"             
#> [10] "collapsefactors"       "colroles"              "copy"                 
#> [13] "datefeatures"          "encode"                "encodeimpact"         
#> [16] "encodelmer"            "featureunion"          "filter"               
#> [19] "fixfactors"            "histbin"               "ica"                  
#> [22] "imputeconstant"        "imputehist"            "imputelearner"        
#> [25] "imputemean"            "imputemedian"          "imputemode"           
#> [28] "imputeoor"             "imputesample"          "kernelpca"            
#> [31] "learner"               "learner_cv"            "missind"              
#> [34] "modelmatrix"           "multiplicityexply"     "multiplicityimply"    
#> [37] "mutate"                "nmf"                   "nop"                  
#> [40] "ovrsplit"              "ovrunite"              "pca"                  
#> [43] "proxy"                 "quantilebin"           "randomprojection"     
#> [46] "randomresponse"        "regravg"               "removeconstants"      
#> [49] "renamecolumns"         "replicate"             "rowapply"             
#> [52] "scale"                 "scalemaxabs"           "scalerange"           
#> [55] "select"                "smote"                 "smotenc"              
#> [58] "spatialsign"           "subsample"             "targetinvert"         
#> [61] "targetmutate"          "targettrafoscalerange" "textvectorizer"       
#> [64] "threshold"             "tunethreshold"         "unbranch"             
#> [67] "vtreat"                "yeojohnson"       
# フィルターの種類
mlr_filters$keys()

#>  [1] "anova"             "auc"               "boruta"           
#>  [4] "carscore"          "carsurvscore"      "cmim"             
#>  [7] "correlation"       "disr"              "find_correlation" 
#> [10] "importance"        "information_gain"  "jmi"              
#> [13] "jmim"              "kruskal_test"      "mim"              
#> [16] "mrmr"              "njmim"             "performance"      
#> [19] "permutation"       "relief"            "selected_features"
#> [22] "univariate_cox"    "variance"          

5.4 Learner(学習器)の構築

訓練データを定義した後の次のステップは、使用するアルゴリズムを選択することです。
learnerを使って構築します。具体的には、アルゴリズムの名前、そのパラメーターとハイパーパラメーター、および新しい観測値を予測するときに表示される結果のタイプを格納します。

使用することができるアルゴリズム一覧を表示します。
注)学習器の数は拡張しています

mlr_learners$keys()
       
#>   [1] "classif.abess"                   "classif.AdaBoostM1"             
#>   [3] "classif.bart"                    "classif.bayes_net"              
#>   [5] "classif.C50"                     "classif.catboost"               
#>   [7] "classif.cforest"                 "classif.ctree"                  
#>   [9] "classif.cv_glmnet"               "classif.debug"                  
#>  [11] "classif.decision_stump"          "classif.decision_table"         
#>  [13] "classif.earth"                   "classif.featureless"            
#>  [15] "classif.fnn"                     "classif.gam"                    
#>  [17] "classif.gamboost"                "classif.gausspr"                
#>  [19] "classif.gbm"                     "classif.glmboost"               
#>  [21] "classif.glmer"                   "classif.glmnet"                 
#>  [23] "classif.IBk"                     "classif.imbalanced_rfsrc"       
#>  [25] "classif.J48"                     "classif.JRip"                   
#>  [27] "classif.kknn"                    "classif.kstar"                  
#>  [29] "classif.ksvm"                    "classif.lda"                    
#>  [31] "classif.liblinear"               "classif.lightgbm"               
#>  [33] "classif.LMT"                     "classif.log_reg"                
#>  [35] "classif.logistic"                "classif.lssvm"                  
#>  [37] "classif.mob"                     "classif.multilayer_perceptron"  
#>  [39] "classif.multinom"                "classif.naive_bayes"            
#>  [41] "classif.naive_bayes_multinomial" "classif.naive_bayes_weka"       
#>  [43] "classif.nnet"                    "classif.OneR"                   
#>  [45] "classif.PART"                    "classif.priority_lasso"         
#>  [47] "classif.qda"                     "classif.random_forest_weka"     
#>  [49] "classif.random_tree"             "classif.randomForest"           
#>  [51] "classif.ranger"                  "classif.reptree"                
#>  [53] "classif.rfsrc"                   "classif.rpart"                  
#>  [55] "classif.rpf"                     "classif.sgd"                    
#>  [57] "classif.simple_logistic"         "classif.smo"                    
#>  [59] "classif.svm"                     "classif.voted_perceptron"       
#>  [61] "classif.xgboost"                 "clust.agnes"                    
#>  [63] "clust.ap"                        "clust.bico"                     
#>  [65] "clust.birch"                     "clust.cmeans"                   
#>  [67] "clust.cobweb"                    "clust.dbscan"                   
#>  [69] "clust.dbscan_fpc"                "clust.diana"                    
#>  [71] "clust.em"                        "clust.fanny"                    
#>  [73] "clust.featureless"               "clust.ff"                       
#>  [75] "clust.hclust"                    "clust.hdbscan"                  
#>  [77] "clust.kkmeans"                   "clust.kmeans"                   
#>  [79] "clust.MBatchKMeans"              "clust.mclust"                   
#>  [81] "clust.meanshift"                 "clust.optics"                   
#>  [83] "clust.pam"                       "clust.SimpleKMeans"             
#>  [85] "clust.xmeans"                    "dens.kde_ks"                    
#>  [87] "dens.locfit"                     "dens.logspline"                 
#>  [89] "dens.mixed"                      "dens.nonpar"                    
#>  [91] "dens.pen"                        "dens.plug"                      
#>  [93] "dens.spline"                     "regr.abess"                     
#>  [95] "regr.bart"                       "regr.catboost"                  
#>  [97] "regr.cforest"                    "regr.ctree"                     
#>  [99] "regr.cubist"                     "regr.cv_glmnet"                 
#> [101] "regr.debug"                      "regr.decision_stump"            
#> [103] "regr.decision_table"             "regr.earth"                     
#> [105] "regr.featureless"                "regr.fnn"                       
#> [107] "regr.gam"                        "regr.gamboost"                  
#> [109] "regr.gaussian_processes"         "regr.gausspr"                   
#> [111] "regr.gbm"                        "regr.glm"                       
#> [113] "regr.glmboost"                   "regr.glmnet"                    
#> [115] "regr.IBk"                        "regr.kknn"                      
#> [117] "regr.km"                         "regr.kstar"                     
#> [119] "regr.ksvm"                       "regr.liblinear"                 
#> [121] "regr.lightgbm"                   "regr.linear_regression"         
#> [123] "regr.lm"                         "regr.lmer"                      
#> [125] "regr.m5p"                        "regr.M5Rules"                   
#> [127] "regr.mars"                       "regr.mob"                       
#> [129] "regr.multilayer_perceptron"      "regr.nnet"                      
#> [131] "regr.priority_lasso"             "regr.random_forest_weka"        
#> [133] "regr.random_tree"                "regr.randomForest"              
#> [135] "regr.ranger"                     "regr.reptree"                   
#> [137] "regr.rfsrc"                      "regr.rpart"                     
#> [139] "regr.rpf"                        "regr.rsm"                       
#> [141] "regr.rvm"                        "regr.sgd"                       
#> [143] "regr.simple_linear_regression"   "regr.smo_reg"                   
#> [145] "regr.svm"                        "regr.xgboost"                   
#> [147] "surv.akritas"                    "surv.aorsf"                     
#> [149] "surv.bart"                       "surv.blackboost"                
#> [151] "surv.cforest"                    "surv.coxboost"                  
#> [153] "surv.coxtime"                    "surv.ctree"                     
#> [155] "surv.cv_coxboost"                "surv.cv_glmnet"                 
#> [157] "surv.deephit"                    "surv.deepsurv"                  
#> [159] "surv.dnnsurv"                    "surv.flexible"                  
#> [161] "surv.gamboost"                   "surv.gbm"                       
#> [163] "surv.glmboost"                   "surv.glmnet"                    
#> [165] "surv.loghaz"                     "surv.mboost"                    
#> [167] "surv.nelson"                     "surv.parametric"                
#> [169] "surv.pchazard"                   "surv.penalized"                 
#> [171] "surv.priority_lasso"             "surv.ranger"                    
#> [173] "surv.rfsrc"                      "surv.svm"                       
[#> 175] "surv.xgboost.aft"                "surv.xgboost.cox"         

Learner(学習器)を作成します。

learner_svm = lrn("regr.svm")
learner_svm

#> <LearnerRegrSVM:regr.svm>: Support Vector Machine
#> * Model: -
#> * Parameters: list()
#> * Packages: mlr3, mlr3learners, e1071
#> * Predict Types:  [response]
#> * Feature Types: logical, integer, numeric
#> * Properties: -
  • Packages:アルゴリズムを用いているパッケージ名
  • Predict_types:予測のタイプを示します。
    分類:「response」は最も確率の高いクラスのみを返し、「prob」は各クラスの確率も返します。
    回帰:「response」は予測値を返し、「se」は予測の標準誤差を返します。
  • Feature_types:アルゴリズムが受け入れることができる変数のタイプ
    (要因を処理できない変数を識別するために重要です。
  • Properties:アルゴリズムの追加プロパティ。
    例えば、「missings」は欠測値を処理できることを意味し、「importance」はアルゴリズムが説明変数の重要度を計算できることを意味します。
learner_svm$param_set

image.png
$param_set でパラメーターの有効な値の範囲、およびデフォルト値の詳細が見れます。

さらに'levels'の詳細を見ます。

learner_svm$param_set$levels %>% purrr::discard(is.null)

#> $fitted
#> [1]  TRUE FALSE
#> 
#> $kernel
#> [1] "linear"     "polynomial" "radial"     "sigmoid"   
#> 
#> $shrinking
#> [1]  TRUE FALSE
#> 
#> $type
#> [1] "eps-regression" "nu-regression" 

学習器の作成時に、パラメーターを設定することもできます。

learner_svm = lrn("regr.svm", type = "eps-regression", kernel = "linear", cost = 0.5)
learner_svm

#> <LearnerRegrSVM:regr.svm>: Support Vector Machine
#> * Model: -
#> * Parameters: cost=0.5, kernel=linear, type=eps-regression
#> * Packages: mlr3, mlr3learners, e1071
#> * Predict Types:  [response]
#> * Feature Types: logical, integer, numeric
#> * Properties: -

パラメータに加えて、変更できるもう1つの重要な機能predict_typeは、予測で返される結果のタイプを決定します。

learner_svm$predict_type = "response"

5.5 前処理の設定

mlr3pipelinesは、ppl("robustify")と呼ばれるシンプルで再利用可能なパイプラインを提供しています。
このパイプラインには、以下のPipeOpsが全て含まれています(いくつかは複数回適用され、ほとんどはセレクタを使用します):

  • po("removeconstants") - 定数特徴が削除されます。
  • po("colapply") - 文字および順序特徴がカテゴリとしてエンコードされ、日付/時刻特徴が数値としてエンコードされます。
  • po("imputehist") - 数値特徴がヒストグラムサンプリングによってインプットされます。
  • po("imputesample") - 論理特徴が経験分布からのサンプリングによってインプットされます。
  • po("missind") - 欠損データのインジケータが、インポートされた数値および論理変数に追加されます。
  • po("imputeoor") - カテゴリ特徴量の欠落値が、新しいレベルでエンコードされます。
  • po("fixfactors") - 予測およびトレーニング中に同じレベルが存在するように、カテゴリ特徴量のレベルを修正します (これは、空の要素レベルを削除することを含む場合があります)。
  • po("imputesample") - 前のステップでレベルを削除することで導入されたカテゴリ特徴量の欠損値を、経験分布からサンプリングしてインプットします。
  • po("collapsefactors") - max_cardinality引数(デフォルトは1000)で制御されるレベル数未満になるまで、カテゴリ特徴量のレベルが(トレーニングデータで最もレアな因子から)折りたたまれます。
  • po("encode") - カテゴリ特徴量がワンホットエンコーディングされます。
  • po("removeconstants") - 前のステップで作成された可能性のある定数特徴量が削除されます。

注)欠測値の推定からワンホットまで基本的な前処理が全て入っている優れものです。

pos = ppl("robustify")
pos$train(task_train)

#> $removeconstants_postrobustify.output
#> <TaskRegr:task_dat> (1382 x 29)
#> * Target: price
#> * Properties: -
#> * Features (28):
#>   - dbl (22): bathrooms, centralAir.No, centralAir.Yes,
#>     fireplaces.0, fireplaces.1, fireplaces.2, fireplaces.3,
#>     fireplaces.4, fuel.electric, fuel.gas, fuel.oil,
#>     heating.electric, heating.hot.air, heating.hot.water.steam,
#>     lotSize, newConstruction.No, newConstruction.Yes, sewer.none,
#>     sewer.public.commercial, sewer.septic, waterfront.No,
#>     waterfront.Yes
#>   - int (6): age, bedrooms, landValue, livingArea, pctCollege, rooms

factor因子が全てワンホットエンコーディングされ特徴量は29まで増えています。

5.6 データの学習

taskとlearnerが作成できると、前処理と組み合わせ、モデルを訓練データで学習させます。
すべての説明変数を用いて、家の価格を予測するサポートベクトルマシンモデル(SVM)に進みます。
学習は、taskに格納されている訓練データを使用します。

learner_svm = as_learner(ppl("robustify") %>>% lrn("regr.svm"))
learner_svm$train(task_train)

5.7 交差検証

モデルの最終的な目的は、将来のデータまたはモデルが以前に「見たことがない」データにおける目的変数を予測することです。
モデルのトレーニング後に表示される誤差は、通常、訓練誤差です。これはモデルがすでに「見た」データを予測する際に発生する誤差です。
これらの誤差は新しいデータに直面したときのモデルの動作を現実的に見積もることはできません。
より正確な推定を達成するために、リサンプリングに基づく交差検証法があります。
mlr3には、次のrリサンプリング戦略が組み込まれています。

  • Bootstrap(mlr_resamplings_bootstrap)
  • V-Fold Cross-Validation(mlr_resamplings_cv)
  • Repeated V-Fold Cross-Validation(mlr_resamplings_repeated_cv)
  • Monte Carlo Cross-Validation(mlr_resamplings_subsampling)
    など

それぞれが内部で異なる動作をしますが、それらはすべて、訓練データを元に作成した様々なサブセットデータを毎回使用し、各繰り返しの誤差の推定値を取得することで、モデルを評価するという考えに基づいています。

mlr_resamplings$keys()

#> [1] "bootstrap"   "custom"      "custom_cv"   "cv"          "holdout"    
#> [6] "insample"    "loo"         "repeated_cv" "subsampling"
set.seed(123)
resampling_cv = rsmp("cv", folds = 5)
resampling_cv

#> <ResamplingCV>: Cross-Validation
#> * Iterations: 5
#> * Instantiated: FALSE
#> * Parameters: folds=5

K-folds クロスバリデーションを選びます。5分割します。
この次はモデルを評価する目的関数を定義します。

mlr_measures$keys()

#>  [1] "aic"                  "bic"                  "classif.acc"         
#>  [4] "classif.auc"          "classif.bacc"         "classif.bbrier"      
#>  [7] "classif.ce"           "classif.costs"        "classif.dor"         
#> [10] "classif.fbeta"        "classif.fdr"          "classif.fn"          
#> [13] "classif.fnr"          "classif.fomr"         "classif.fp"          
#> [16] "classif.fpr"          "classif.logloss"      "classif.mauc_au1p"   
#> [19] "classif.mauc_au1u"    "classif.mauc_aunp"    "classif.mauc_aunu"   
#> [22] "classif.mauc_mu"      "classif.mbrier"       "classif.mcc"         
#> [25] "classif.npv"          "classif.ppv"          "classif.prauc"       
#> [28] "classif.precision"    "classif.recall"       "classif.sensitivity" 
#> [31] "classif.specificity"  "classif.tn"           "classif.tnr"         
#> [34] "classif.tp"           "classif.tpr"          "clust.ch"            
#> [37] "clust.dunn"           "clust.silhouette"     "clust.wss"           
#> [40] "debug_classif"        "internal_valid_score" "oob_error"           
#> [43] "regr.bias"            "regr.ktau"            "regr.mae"            
#> [46] "regr.mape"            "regr.maxae"           "regr.medae"          
#> [49] "regr.medse"           "regr.mse"             "regr.msle"           
#> [52] "regr.pbias"           "regr.pinball"         "regr.rae"            
#> [55] "regr.rmse"            "regr.rmsle"           "regr.rrse"           
#> [58] "regr.rse"             "regr.rsq"             "regr.sae"            
#> [61] "regr.smape"           "regr.srho"            "regr.sse"            
#> [64] "selected_features"    "sim.jaccard"          "sim.phi"             
#> [67] "time_both"            "time_predict"         "time_train"

RMSE(二乗平均平方根誤差)を目的関数にします。

metric = msr("regr.rmse")
result_resampling = resample(
                           task         = task_train,
                           learner      = learner_svm,
                           resampling   = resampling_cv,
                           store_models = FALSE
                         )  
result_resampling

#> INFO  [14:24:09.415] [mlr3] Applying learner 
#> 'removeconstants_prerobustify.char_to_fct.POSIXct_to_dbl.ord_to_fct.imputehist.missind.impute_logicals.featureunion_robustify.imputeoor.fixfactors.imputesample.collapsefactors.encode.removeconstants_postrobustify.regr.svm' on task 'task_dat' (iter 1/5)
#> INFO  [14:24:10.220] [mlr3] Applying learner 'removeconstants_prerobustify.char_to_fct.POSIXct_to_dbl.ord_to_fct.imputehist.missind.impute_logicals.featureunion_robustify.imputeoor.fixfactors.imputesample.collapsefactors.encode.removeconstants_postrobustify.regr.svm' on task 'task_dat' (iter 2/5)
#> INFO  [14:24:10.804] [mlr3] Applying learner 'removeconstants_prerobustify.char_to_fct.POSIXct_to_dbl.ord_to_fct.imputehist.missind.impute_logicals.featureunion_robustify.imputeoor.fixfactors.imputesample.collapsefactors.encode.removeconstants_postrobustify.regr.svm' on task 'task_dat' (iter 3/5)
#> INFO  [14:24:11.379] [mlr3] Applying learner 'removeconstants_prerobustify.char_to_fct.POSIXct_to_dbl.ord_to_fct.imputehist.missind.impute_logicals.featureunion_robustify.imputeoor.fixfactors.imputesample.collapsefactors.encode.removeconstants_postrobustify.regr.svm' on task 'task_dat' (iter 4/5)
#> INFO  [14:24:11.968] [mlr3] Applying learner 'removeconstants_prerobustify.char_to_fct.POSIXct_to_dbl.ord_to_fct.imputehist.missind.impute_logicals.featureunion_robustify.imputeoor.fixfactors.imputesample.collapsefactors.encode.removeconstants_postrobustify.regr.svm' on task 'task_dat' (iter 5/5)
#> <ResampleResult> with 5 resampling iterations

回帰分析の場合、デフォルトはmseになっています。5回のiterationの結果の平均です。

result_resampling$aggregate(measures = metric)

#> regr.rmse 
#>  61275.52 

iteration毎のrmseを表示します。

result_resampling$score(msr("regr.rmse"))

image.png

バリデーションの結果の分布を示します

llibrary(mlr3viz)
autoplot(result_resampling, measure = metric)

image.png

バリデーションの予測を保存すると、モデルの真値と予測値との残差を評価できるようになります。

pred_valid = result_resampling$prediction() %>%
                           as.data.table() %>%
                           mutate(
                             residual= response - truth
                           )

p1 = ggplot(data = pred_valid, aes(x = truth, y = response)) +
      geom_point(alpha = 0.3) +
      geom_abline(slope = 1, intercept = 0, color = "firebrick") +
      labs(title = "Predicted value vs. actual value") +
      theme_bw()


p2 = ggplot(data = pred_valid, aes(x = row_ids, y = residual)) +
      geom_point(alpha = 0.3) +
      geom_hline(yintercept =  0, color = "firebrick") +
      labs(title = "Model residual") +
      theme_bw()

p3 = ggplot(data = pred_valid, aes(x = residual)) +
      geom_density() + 
      labs(title = "Distribution of model residual") +
      theme_bw()

p4 = ggplot(data = pred_valid, aes(sample = residual)) +
      geom_qq() +
      geom_qq_line(color = "firebrick") +
      labs(title = "Q-Q model residual") +
      theme_bw()

ggarrange(plotlist = list(p1, p2, p3, p4)) %>%
annotate_figure(
  top = text_grob("モデルの残差分布", size = 15, face = "bold")
)

image.png

GraphLearnerによるpipline処理を図示化します。

learner_svm$graph$plot()

image.png

GraphLearnerは、inputからoutputまでの一連のアクションを前処理として実行します。
GraphLearnerが作成されると、learnerそれが1つであるかのように使用できます。

5.8 テストデータを用いた予測

Learnerを使用してモデルをトレーニングすると、$predict()メソッドを用いて、新しいデータの予測ができます。

pred = learner_svm$predict(
                  task = task_test
                )
as.data.table(pred) %>% head(5)

image.png

テストデータのrmseの計算に進みます。

pred$score(measures = msr("regr.rmse"))

#> regr.rmse 
#>  61000.09 

クロスバリデーションの結果は 61275.52 テストデータは61000.09 と差はほとんどない結果となりました。

6.参考

Applied Machine Learning Using mlr3 in R

7.enjoy

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?