2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

mlr3で「Rによる機械学習」:第6章 「数値データの予測 --- 回帰法」

Last updated at Posted at 2021-08-31

準備

必要なパッケージをインストールしておきます.

install.packages("rpart.plot")
install.packages("GGally")
library(mlr3verse)
library(mlr3extralearners)
library(tidyverse)
library(lubridate)
library(rpart.plot)
library(GGally)

6.1 回帰とは何か

Dalal, Flowlkes and Hoadley (1989)から,スペースシャトルの発射時の気温とOリングの危機的事象の数のデータを入力します.

launch <- tribble(
  ~Flight, ~Date,      ~distress, ~temperature, ~field_check_pressure,
  "1",     "04/12/81", 0,         66,           50,
  "2",     "11/12/81", 1,         70,           50,
  "3",     "03/22/82", 0,         69,           50,
  "5",     "11/11/82", 0,         68,           50,
  "6",     "04/04/83", 0,         67,           50,
  "7",     "06/18/83", 0,         72,           50,
  "8",     "08/30/83", 0,         73,           100,
  "9",     "11/28/83", 0,         70,           100,
  "41-B",  "02/03/84", 1,         57,           200,
  "41-C",  "04/06/84", 1,         63,           200,
  "41-D",  "08/30/84", 1,         70,           200,
  "41-G",  "10/05/84", 0,         78,           200,
  "51-A",  "11/08/84", 0,         67,           200,
  "51-C",  "01/24/85", 2,         53,           200,
  "51-D",  "04/12/85", 0,         67,           200,
  "51-B",  "04/29/85", 0,         75,           200,
  "51-G",  "06/17/85", 0,         70,           200,
  "51-F",  "07/29/85", 0,         81,           200,
  "51-I",  "07/27/85", 0,         76,           200,
  "51-J",  "10/03/85", 0,         79,           200,
  "61-A",  "10/30/85", 2,         75,           200,
  "61-B",  "11/26/85", 0,         76,           200,
  "61-C",  "01/12/86", 1,         58,           200
) %>% 
  mutate(Date = mdy(Date)) %>%
  mutate(flight_num = row_number())
launch
## # A tibble: 23 x 6
##    Flight Date       distress temperature field_check_pressure flight_num
##    <chr>  <date>        <dbl>       <dbl>                <dbl>      <int>
##  1 1      1981-04-12        0          66                   50          1
##  2 2      1981-11-12        1          70                   50          2
##  3 3      1982-03-22        0          69                   50          3
##  4 5      1982-11-11        0          68                   50          4
##  5 6      1983-04-04        0          67                   50          5
##  6 7      1983-06-18        0          72                   50          6
##  7 8      1983-08-30        0          73                  100          7
##  8 9      1983-11-28        0          70                  100          8
##  9 41-B   1984-02-03        1          57                  200          9
## 10 41-C   1984-04-06        1          63                  200         10
## # ... with 13 more rows

散布図を描きます.

g <- ggplot(launch, aes(temperature, distress)) + 
  geom_point() +
  xlab("Temperature") + ylab("Distress")
print(g)

download.png

回帰直線を描き込みます.

g <- g + stat_smooth(method = "lm", se = FALSE)
print(g)

download.png

回帰係数を推計します.

b <- cov(launch$temperature, launch$distress) / var(launch$temperature)
print(b)
## [1] -0.04753968
a <- mean(launch$distress) - b * mean(launch$temperature)
print(a)
## [1] 3.698413

発射時の気温とOリングの危機的事象の数の相関を計算します.

r <- cov(launch$temperature, launch$distress) / 
  (sd(launch$temperature) * sd(launch$distress))
print(r)
## [1] -0.5111264

独立変数と従属変数を引数として取り,回帰係数のベクトルを返す関数を定義します.

reg <- function(y, x) {
  x <- as.matrix(x)
  x <- cbind(intercept = 1, x)
  b <- solve(t(x) %*% x) %*% t(x) %*% y
  colnames(b) <- "estimate"
  print(b)
}

この関数を使って気温とOリングの危機的事象の数の単純線形回帰係数を求めてみます.

reg(y = launch$distress, x = launch[4])
##                estimate
## intercept    3.69841270
## temperature -0.04753968

多重線形回帰の回帰係数を計算します.

reg(y = launch$distress, x = launch[4:6])
##                          estimate
## intercept             3.527093383
## temperature          -0.051385940
## field_check_pressure  0.001757009
## flight_num            0.014292843

Rに用意されている関数を使って回帰分析

せっかくですのでRに用意されている関数を用いて,もう少しきちんと回帰分析をしてみましょう.

cor()を用いると説明変数間の関係を表す相関行列を表示することができます.

cor(launch[
  c("distress",
    "temperature",
    "field_check_pressure",
    "flight_num")
  ])
##                        distress temperature field_check_pressure flight_num
## distress              1.0000000 -0.51112639           0.28466627  0.1735779
## temperature          -0.5111264  1.00000000           0.03981769  0.2307702
## field_check_pressure  0.2846663  0.03981769           1.00000000  0.8399324
## flight_num            0.1735779  0.23077017           0.83993237  1.0000000

GGallyパッケージに含まれるggpairs()を用いると説明変数間の関係を図示することができます.

g_pairs <- launch %>%
  select(distress, temperature, field_check_pressure, flight_num) %>%
  ggpairs()
print(g_pairs)

download.png

回帰分析を行うためにはlm()関数を用います.引数として~の左側に被説明変数,右側に説明変数を書きます.

ここではまず単回帰を行ってみます.計算結果が上で求めたreg(y = launch$distress, x = launch[4])の結果と一致していることがわかります.

distress_lm <- lm(distress ~ temperature, data = launch)
print(distress_lm)
## 
## Call:
## lm(formula = distress ~ temperature, data = launch)
## 
## Coefficients:
## (Intercept)  temperature  
##     3.69841     -0.04754

次に説明変数をすべて使った重回帰です.こちらの計算結果はreg(y = launch$distress, x = launch[4:6])に対応します.

distress_lm_all <- 
  lm(distress ~ temperature + field_check_pressure + flight_num,
     data = launch)
print(distress_lm_all)
## 
## Call:
## lm(formula = distress ~ temperature + field_check_pressure + 
##     flight_num, data = launch)
## 
## Coefficients:
##          (Intercept)           temperature  field_check_pressure  
##             3.527093             -0.051386              0.001757  
##           flight_num  
##             0.014293

結果の詳細を知るためにはsummary()を使います.

print(summary(distress_lm_all))
## 
## Call:
## lm(formula = distress ~ temperature + field_check_pressure + 
##     flight_num, data = launch)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -0.65003 -0.24414 -0.11219  0.01279  1.67530 
## 
## Coefficients:
##                       Estimate Std. Error t value Pr(>|t|)  
## (Intercept)           3.527093   1.307024   2.699   0.0142 *
## temperature          -0.051386   0.018341  -2.802   0.0114 *
## field_check_pressure  0.001757   0.003402   0.517   0.6115  
## flight_num            0.014293   0.035138   0.407   0.6887  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.565 on 19 degrees of freedom
## Multiple R-squared:   0.36,  Adjusted R-squared:  0.259 
## F-statistic: 3.563 on 3 and 19 DF,  p-value: 0.03371

かろうじて5%有意水準でtemperatureのみ帰無仮説を棄却することができることがわかります.

lmr3で回帰分析

lmr3パッケージにはlm()を呼び出すための学習器であるregr.lmがありますので,ついでにこれを使って回帰分析してみましょう.

learner_lm <- lrn("regr.lm")
task <- launch %>%
  select(-Date) %>%
  as_task_regr(target = "distress")
task$set_col_roles("Flight", roles = "name")
print(task)
## <TaskRegr:.> (23 x 4)
## * Target: distress
## * Properties: -
## * Features (3):
##   - dbl (2): field_check_pressure, temperature
##   - int (1): flight_num

実際にデータが含まれていることを確認します.

print(task$data())
##     distress field_check_pressure flight_num temperature
##  1:        0                   50          1          66
##  2:        1                   50          2          70
##  3:        0                   50          3          69
##  4:        0                   50          4          68
##  5:        0                   50          5          67
##  6:        0                   50          6          72
##  7:        0                  100          7          73
##  8:        0                  100          8          70
##  9:        1                  200          9          57
## 10:        1                  200         10          63
## 11:        1                  200         11          70
## 12:        0                  200         12          78
## 13:        0                  200         13          67
## 14:        2                  200         14          53
## 15:        0                  200         15          67
## 16:        0                  200         16          75
## 17:        0                  200         17          70
## 18:        0                  200         18          81
## 19:        0                  200         19          76
## 20:        0                  200         20          79
## 21:        2                  200         21          75
## 22:        0                  200         22          76
## 23:        1                  200         23          58
##     distress field_check_pressure flight_num temperature

いったんタスクにしてしまえばペアプロットも簡単に描けます.

autoplot(task, type = "pairs")

download.png

リサンプリングは全部を訓練データにしてしまえばOKです.

resampling <- rsmp("holdout", ratio = 1)
rr <- resample(task, learner_lm, resampling, store_models = TRUE)
## INFO  [10:03:40.607] [mlr3]  Applying learner 'regr.lm' on task '.' (iter 1/1)

当たり前ですが計算結果は一致していますね.

print(rr$learners[[1]]$model)
## 
## Call:
## stats::lm(formula = task$formula(), data = task$data())
## 
## Coefficients:
##          (Intercept)  field_check_pressure            flight_num  
##             3.527093              0.001757              0.014293  
##          temperature  
##            -0.051386

6.2 実例 --- 線形回帰を使った医療費の予測

データがないため省略.

6.3 回帰木とモデル木とは何か

例を用いて標準偏差減少(Standard Deviation Reduction:SDR)について計算してみます.

tee <- c(1, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 7)
at1 <- tee[1:9]
at2 <- tee[10:15]
bt1 <- tee[1:7]
bt2 <- tee[8:15]
sdr_a <- 
  sd(tee) - 
  (length(at1) / length(tee) * sd(at1) + length(at2) / length(tee) * sd(at2))
sdr_b <- 
  sd(tee) -
  (length(bt1) / length(tee) * sd(bt1) + length(bt2) / length(tee) * sd(bt2))
print(paste0("sdr_a: ", sdr_a))
## [1] "sdr_a: 1.20281456809792"
print(paste0("sdr_b: ", sdr_b))
## [1] "sdr_b: 1.39275139353039"

bt*による分割の方が標準偏差が減少していることがわかります.

6.4 実例 --- 回帰木とモデル木によるワインの品質の見積もり

データを準備します.

url <- "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv"
wine <- read_delim(url,
                   delim = ";",
                   col_names = TRUE,
                   col_types = "dddddddddddi"
                 ) %>%
  as_tibble(.name_repair = "universal")
## New names:
## * `fixed acidity` -> fixed.acidity
## * `volatile acidity` -> volatile.acidity
## * `citric acid` -> citric.acid
## * `residual sugar` -> residual.sugar
## * `free sulfur dioxide` -> free.sulfur.dioxide
## * ...
str(wine)
## tibble [4,898 x 12] (S3: tbl_df/tbl/data.frame)
##  $ fixed.acidity       : num [1:4898] 7 6.3 8.1 7.2 7.2 8.1 6.2 7 6.3 8.1 ...
##  $ volatile.acidity    : num [1:4898] 0.27 0.3 0.28 0.23 0.23 0.28 0.32 0.27 0.3 0.22 ...
##  $ citric.acid         : num [1:4898] 0.36 0.34 0.4 0.32 0.32 0.4 0.16 0.36 0.34 0.43 ...
##  $ residual.sugar      : num [1:4898] 20.7 1.6 6.9 8.5 8.5 6.9 7 20.7 1.6 1.5 ...
##  $ chlorides           : num [1:4898] 0.045 0.049 0.05 0.058 0.058 0.05 0.045 0.045 0.049 0.044 ...
##  $ free.sulfur.dioxide : num [1:4898] 45 14 30 47 47 30 30 45 14 28 ...
##  $ total.sulfur.dioxide: num [1:4898] 170 132 97 186 186 97 136 170 132 129 ...
##  $ density             : num [1:4898] 1.001 0.994 0.995 0.996 0.996 ...
##  $ pH                  : num [1:4898] 3 3.3 3.26 3.19 3.19 3.26 3.18 3 3.3 3.22 ...
##  $ sulphates           : num [1:4898] 0.45 0.49 0.44 0.4 0.4 0.44 0.47 0.45 0.49 0.45 ...
##  $ alcohol             : num [1:4898] 8.8 9.5 10.1 9.9 9.9 10.1 9.6 8.8 9.5 11 ...
##  $ quality             : int [1:4898] 6 6 6 6 6 6 6 6 6 6 ...

ヒストグラムを用いてqualityの分布を確認してみます.

g <- ggplot(wine, aes(quality)) + 
  geom_histogram(binwidth = 1, colour = "black", fill = "white") +
  xlab("Quality") + ylab("Count")
print(g)

download.png

回帰木モデルの訓練から行います.学習器はregr.rpartを使います.

set.seed(0)
learner <- lrn("regr.rpart")
task <- as_task_regr(wine, target = "quality")
resampling <- rsmp("holdout", ratio = 0.75)
rr <- resample(task, learner, resampling, store_models = TRUE)
## INFO  [10:03:42.066] [mlr3]  Applying learner 'regr.rpart' on task 'wine' (iter 1/1)

訓練結果を取りだします.

m.rpart <- rr$learners[[1]]$model
print(m.rpart)
## n= 3674 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 3674 2923.10100 5.878879  
##    2) alcohol< 10.85 2298 1408.37900 5.606614  
##      4) volatile.acidity>=0.2525 1223  613.98040 5.360589  
##        8) free.sulfur.dioxide< 17.5 179   92.94972 4.983240 *
##        9) free.sulfur.dioxide>=17.5 1044  491.17240 5.425287 *
##      5) volatile.acidity< 0.2525 1075  636.15440 5.886512 *
##    3) alcohol>=10.85 1376 1059.88900 6.333576  
##      6) free.sulfur.dioxide< 11.5 86  106.33720 5.383721 *
##      7) free.sulfur.dioxide>=11.5 1290  870.78760 6.396899  
##       14) alcohol< 11.74167 641  430.43990 6.199688 *
##       15) alcohol>=11.74167 649  390.79510 6.591680 *

summary(m.rpart)も見ることができますが,長くなるので省略します.

決定木の理解のためにrpart.plotパッケージ1を使って可視化します.

rpart.plot(m.rpart, digits = 3)

download.png

パラメータを変えてみます.

rpart.plot(m.rpart, digits = 4, fallen.leaves = TRUE, type = 3, extra = 101)

download.png

予測値と実測値の要約統計量を確認します.

p.rpart <- rr$predictions()[[1]]
print(summary(p.rpart$response))
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   4.983   5.425   5.887   5.887   6.200   6.592
print(summary(p.rpart$truth))
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   3.000   5.000   6.000   5.875   6.000   9.000

予測値と実測値の相関係数を確認します.予測値を訓練結果からp.rpartに取り出します.

print(cor(p.rpart$response, p.rpart$truth))
## [1] 0.524354

平均絶対誤差(Mean Absolute Error: MAE)による性能の測定をします.

p.rpart$score(msr("regr.mae"))
##  regr.mae 
## 0.5837126

次にモデル木を用いてみます.mlr3にはM5'アルゴリズムの学習器はないようですので,M5'の発展であるCubistアルゴリズムによる学習器であるregr.cubistを使います.

learner <- lrn("regr.cubist")
rr <- resample(task, learner, resampling, store_models = TRUE)
## INFO  [10:03:42.843] [mlr3]  Applying learner 'regr.cubist' on task 'wine' (iter 1/1)

訓練結果をm.cubistとして取り出します.

m.cubist <- rr$learners[[1]]$model
print(m.cubist)
## 
## Call:
## cubist.default(x = x, y = y, committees =
##  self$param_set$values$committees, control = control, weights = if
##  ("weights" %in% task$properties) task$weights$weight else NULL)
## 
## Number of samples: 3674 
## Number of predictors: 11 
## 
## Number of committees: 1 
## Number of rules: 9
print(summary(m.cubist))
## 
## Call:
## cubist.default(x = x, y = y, committees =
##  self$param_set$values$committees, control = control, weights = if
##  ("weights" %in% task$properties) task$weights$weight else NULL)
## 
## 
## Cubist [Release 2.07 GPL Edition]  Thu Sep 02 10:03:42 2021
## ---------------------------------
## 
##     Target attribute `outcome'
## 
## Read 3674 cases (12 attributes) from undefined.data
## 
## Model:
## 
##   Rule 1: [164 cases, mean 5.2, range 3 to 7, est err 0.4]
## 
##     if
##  alcohol <= 10
##  citric.acid <= 0.26
##  free.sulfur.dioxide > 43.5
##  volatile.acidity > 0.205
##     then
##  outcome = 67.2 + 0.308 alcohol - 62 density - 0.93 pH
##            + 0.028 residual.sugar - 1.04 sulphates
##            - 0.41 volatile.acidity + 0.26 citric.acid
## 
##   Rule 2: [258 cases, mean 5.2, range 3 to 7, est err 0.5]
## 
##     if
##  alcohol <= 10
##  citric.acid <= 0.26
##  free.sulfur.dioxide <= 43.5
##  volatile.acidity > 0.205
##     then
##  outcome = 244.8 - 243 density + 0.091 residual.sugar
##            + 0.019 free.sulfur.dioxide + 1.68 citric.acid
##            - 0.34 volatile.acidity + 0.18 pH + 0.02 fixed.acidity
##            + 0.013 alcohol + 0.09 sulphates
## 
##   Rule 3: [45 cases, mean 5.3, range 5 to 6, est err 0.3]
## 
##     if
##  citric.acid > 0.26
##  residual.sugar > 17.9
##  volatile.acidity > 0.205
##     then
##  outcome = 4.8 - 17 chlorides + 0.058 residual.sugar
## 
##   Rule 4: [171 cases, mean 5.3, range 3 to 6, est err 0.5]
## 
##     if
##  alcohol <= 10
##  citric.acid > 0.26
##  density <= 0.99681
##  total.sulfur.dioxide > 173
##  volatile.acidity > 0.205
##     then
##  outcome = 216.7 - 216 density + 0.108 residual.sugar
##            + 0.005 total.sulfur.dioxide + 0.21 fixed.acidity
##            - 1.74 volatile.acidity + 0.0067 free.sulfur.dioxide
##            + 0.98 sulphates + 2.7 chlorides
## 
##   Rule 5: [230 cases, mean 5.5, range 4 to 8, est err 0.5]
## 
##     if
##  alcohol <= 10
##  citric.acid > 0.26
##  density <= 0.99681
##  total.sulfur.dioxide <= 173
##  volatile.acidity > 0.205
##     then
##  outcome = 103.3 - 99 density - 2.64 volatile.acidity + 0.163 alcohol
##            + 0.0022 total.sulfur.dioxide - 0.06 fixed.acidity
##            + 0.006 residual.sugar + 0.0015 free.sulfur.dioxide
##            - 0.12 citric.acid
## 
##   Rule 6: [395 cases, mean 5.6, range 3 to 9, est err 0.5]
## 
##     if
##  citric.acid > 0.26
##  density > 0.99681
##  residual.sugar <= 17.9
##  volatile.acidity > 0.205
##     then
##  outcome = 6.8 - 1.72 citric.acid - 1.74 volatile.acidity - 4.1 chlorides
##            + 0.63 sulphates
## 
##   Rule 7: [536 cases, mean 5.8, range 3 to 8, est err 0.6]
## 
##     if
##  alcohol > 10
##  free.sulfur.dioxide <= 21
##     then
##  outcome = 278.3 - 279 density + 0.046 free.sulfur.dioxide
##            + 0.126 residual.sugar - 2.09 volatile.acidity + 1.1 pH
##            + 0.03 fixed.acidity + 0.11 sulphates
## 
##   Rule 8: [851 cases, mean 6.2, range 4 to 8, est err 0.5]
## 
##     if
##  volatile.acidity <= 0.205
##     then
##  outcome = 54.5 - 50 density + 0.023 residual.sugar
##            - 0.42 volatile.acidity + 0.2 pH + 0.024 alcohol
##            + 0.03 fixed.acidity + 0.17 sulphates
##            + 0.001 free.sulfur.dioxide
## 
##   Rule 9: [1614 cases, mean 6.3, range 3 to 9, est err 0.6]
## 
##     if
##  alcohol > 10
##  free.sulfur.dioxide > 21
##     then
##  outcome = 259.2 - 262 density + 0.105 residual.sugar + 1.28 pH
##            + 0.2 fixed.acidity + 0.98 sulphates + 0.061 alcohol
##            - 0.39 volatile.acidity + 0.0013 free.sulfur.dioxide
## 
## 
## Evaluation on training data (3674 cases):
## 
##     Average  |error|                0.5
##     Relative |error|               0.69
##     Correlation coefficient        0.63
## 
## 
##  Attribute usage:
##    Conds  Model
## 
##     70%    73%    alcohol
##     60%    86%    free.sulfur.dioxide
##     50%    99%    volatile.acidity
##     30%    25%    citric.acid
##     19%    90%    density
##     10%    91%    residual.sugar
##      9%     9%    total.sulfur.dioxide
##            94%    sulphates
##            86%    fixed.acidity
##            80%    pH
##            14%    chlorides
## 
## 
## Time: 0.2 secs

予測値と実測値の比較を行います.

p.cubist <- rr$predictions()[[1]]
print(summary(p.cubist$response))
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   3.500   5.484   5.951   5.849   6.177   7.356
print(summary(p.cubist$truth))
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    3.00    5.00    6.00    5.85    6.00    9.00
print(cor(p.cubist$response, p.cubist$truth))
## [1] 0.5595234

最後にMAEで性能の測定をします.regr.rpartと比較して改善があったことがわかります.

p.cubist$score(msr("regr.mae"))
##  regr.mae 
## 0.5599107

参考資料

[Dalal, S. R., E. B. Fowlkes, and B. Hoadley (1989) "Risk Analysis of the Space Shuttle: Pre Challenger Prediction of Failure," Journal of the American Statistical Association, Vol. 84, No. 408, pp. 945-957] (https://www.researchgate.net/publication/238670560_Risk_Analysis_of_the_Space_Shuttle_Pre-Challenger_Prediction_of_Failure)

  1. 本当はggpartyパッケージを使いたいのですが….

2
0
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?