Preparation
Install the required packages beforehand; the other packages loaded below are assumed to be installed already.
install.packages("rpart.plot")
install.packages("GGally")
library(mlr3verse)
library(mlr3extralearners)
library(tidyverse)
library(lubridate)
library(rpart.plot)
library(GGally)
6.1 What Is Regression?
Following Dalal, Fowlkes and Hoadley (1989), enter the data on launch temperature and the number of O-ring distress events for the space shuttle.
launch <- tribble(
~Flight, ~Date, ~distress, ~temperature, ~field_check_pressure,
"1", "04/12/81", 0, 66, 50,
"2", "11/12/81", 1, 70, 50,
"3", "03/22/82", 0, 69, 50,
"5", "11/11/82", 0, 68, 50,
"6", "04/04/83", 0, 67, 50,
"7", "06/18/83", 0, 72, 50,
"8", "08/30/83", 0, 73, 100,
"9", "11/28/83", 0, 70, 100,
"41-B", "02/03/84", 1, 57, 200,
"41-C", "04/06/84", 1, 63, 200,
"41-D", "08/30/84", 1, 70, 200,
"41-G", "10/05/84", 0, 78, 200,
"51-A", "11/08/84", 0, 67, 200,
"51-C", "01/24/85", 2, 53, 200,
"51-D", "04/12/85", 0, 67, 200,
"51-B", "04/29/85", 0, 75, 200,
"51-G", "06/17/85", 0, 70, 200,
"51-F", "07/29/85", 0, 81, 200,
"51-I", "07/27/85", 0, 76, 200,
"51-J", "10/03/85", 0, 79, 200,
"61-A", "10/30/85", 2, 75, 200,
"61-B", "11/26/85", 0, 76, 200,
"61-C", "01/12/86", 1, 58, 200
) %>%
mutate(Date = mdy(Date)) %>%
mutate(flight_num = row_number())
launch
## # A tibble: 23 x 6
## Flight Date distress temperature field_check_pressure flight_num
## <chr> <date> <dbl> <dbl> <dbl> <int>
## 1 1 1981-04-12 0 66 50 1
## 2 2 1981-11-12 1 70 50 2
## 3 3 1982-03-22 0 69 50 3
## 4 5 1982-11-11 0 68 50 4
## 5 6 1983-04-04 0 67 50 5
## 6 7 1983-06-18 0 72 50 6
## 7 8 1983-08-30 0 73 100 7
## 8 9 1983-11-28 0 70 100 8
## 9 41-B 1984-02-03 1 57 200 9
## 10 41-C 1984-04-06 1 63 200 10
## # ... with 13 more rows
Draw a scatter plot.
g <- ggplot(launch, aes(temperature, distress)) +
geom_point() +
xlab("Temperature") + ylab("Distress")
print(g)
Add the regression line.
g <- g + stat_smooth(method = "lm", se = FALSE)
print(g)
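Estimate the regression coefficients: the slope b is Cov(x, y) / Var(x), and the intercept a is mean(y) - b * mean(x).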
b <- cov(launch$temperature, launch$distress) / var(launch$temperature)
print(b)
## [1] -0.04753968
a <- mean(launch$distress) - b * mean(launch$temperature)
print(a)
## [1] 3.698413
Compute the correlation between launch temperature and the number of O-ring distress events.
r <- cov(launch$temperature, launch$distress) /
(sd(launch$temperature) * sd(launch$distress))
print(r)
## [1] -0.5111264
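The slope and the correlation share the same covariance, so they are linked by b = r * sd(y) / sd(x). The following is a minimal base-R sketch of that relationship, not part of the original analysis. The temperature and distress values are re-entered as plain vectors so the snippet runs on its own, and the name b_from_r is only illustrative.

```r
# Launch temperature and O-ring distress counts, copied from the launch table above
temperature <- c(66, 70, 69, 68, 67, 72, 73, 70, 57, 63, 70, 78,
                 67, 53, 67, 75, 70, 81, 76, 79, 75, 76, 58)
distress <- c(0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
              0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 1)

# Slope from covariance and variance, as in the text
b <- cov(temperature, distress) / var(temperature)

# Pearson correlation, and the slope recovered from it
r <- cor(temperature, distress)
b_from_r <- r * sd(distress) / sd(temperature)

print(c(b = b, b_from_r = b_from_r))
print(r^2)  # for a simple regression this is the R-squared of the fit
```

Both entries of the first printed vector should agree, which is why a weak correlation and a shallow slope go together when the two standard deviations are fixed.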
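Define a function that takes the dependent and independent variables as arguments and returns the vector of regression coefficients, computed from the normal equations.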
reg <- function(y, x) {
x <- as.matrix(x)
x <- cbind(intercept = 1, x)
b <- solve(t(x) %*% x) %*% t(x) %*% y
colnames(b) <- "estimate"
print(b)
}
Use this function to obtain the simple linear regression coefficients for temperature and the number of O-ring distress events.
reg(y = launch$distress, x = launch[4])
## estimate
## intercept 3.69841270
## temperature -0.04753968
Compute the coefficients of the multiple linear regression.
reg(y = launch$distress, x = launch[4:6])
## estimate
## intercept 3.527093383
## temperature -0.051385940
## field_check_pressure 0.001757009
## flight_num 0.014292843
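The reg() function above inverts t(x) %*% x explicitly, which is fine for a small example but can be numerically fragile when predictors are strongly correlated. As a hedged alternative, the sketch below solves the same least-squares problem with a QR decomposition through base R's qr.solve(). The function name reg_qr is hypothetical, and the built-in mtcars data is used only so the snippet is self-contained.

```r
# Least-squares coefficients via QR decomposition instead of an explicit inverse
reg_qr <- function(y, x) {
  x <- cbind(intercept = 1, as.matrix(x))  # prepend the intercept column
  b <- qr.solve(x, y)                      # least-squares solution of x %*% b = y
  matrix(b, ncol = 1, dimnames = list(colnames(x), "estimate"))
}

# Illustration on the built-in mtcars data (not the launch data)
est <- reg_qr(y = mtcars$mpg, x = mtcars[c("wt", "hp")])
print(est)

# Cross-check against lm(), which also relies on a QR decomposition
ref <- coef(lm(mpg ~ wt + hp, data = mtcars))
print(all.equal(unname(est[, "estimate"]), unname(ref)))
```

Calling reg_qr(y = launch$distress, x = launch[4:6]) should give the same estimates as reg() on the launch data.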
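Regression with R's built-in functions
While we are at it, let us redo the regression analysis more carefully with the functions that R provides.
cor() displays the correlation matrix describing the pairwise relationships among the variables.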
cor(launch[
c("distress",
"temperature",
"field_check_pressure",
"flight_num")
])
## distress temperature field_check_pressure flight_num
## distress 1.0000000 -0.51112639 0.28466627 0.1735779
## temperature -0.5111264 1.00000000 0.03981769 0.2307702
## field_check_pressure 0.2846663 0.03981769 1.00000000 0.8399324
## flight_num 0.1735779 0.23077017 0.83993237 1.0000000
ggpairs() from the GGally package plots the pairwise relationships among the variables.
g_pairs <- launch %>%
select(distress, temperature, field_check_pressure, flight_num) %>%
ggpairs()
print(g_pairs)
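Use lm() to fit a regression. In the formula argument, the dependent variable goes on the left of ~ and the explanatory variables on the right.
Start with a simple regression. The result matches the output of reg(y = launch$distress, x = launch[4]) above.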
distress_lm <- lm(distress ~ temperature, data = launch)
print(distress_lm)
##
## Call:
## lm(formula = distress ~ temperature, data = launch)
##
## Coefficients:
## (Intercept) temperature
## 3.69841 -0.04754
Next, fit a multiple regression using all the explanatory variables. This result corresponds to reg(y = launch$distress, x = launch[4:6]).
distress_lm_all <-
lm(distress ~ temperature + field_check_pressure + flight_num,
data = launch)
print(distress_lm_all)
##
## Call:
## lm(formula = distress ~ temperature + field_check_pressure +
## flight_num, data = launch)
##
## Coefficients:
## (Intercept) temperature field_check_pressure
## 3.527093 -0.051386 0.001757
## flight_num
## 0.014293
Use summary() to see the detailed results.
print(summary(distress_lm_all))
##
## Call:
## lm(formula = distress ~ temperature + field_check_pressure +
## flight_num, data = launch)
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.65003 -0.24414 -0.11219 0.01279 1.67530
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 3.527093 1.307024 2.699 0.0142 *
## temperature -0.051386 0.018341 -2.802 0.0114 *
## field_check_pressure 0.001757 0.003402 0.517 0.6115
## flight_num 0.014293 0.035138 0.407 0.6887
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.565 on 19 degrees of freedom
## Multiple R-squared: 0.36, Adjusted R-squared: 0.259
## F-statistic: 3.563 on 3 and 19 DF, p-value: 0.03371
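Among the explanatory variables, only temperature allows the null hypothesis of a zero coefficient to be rejected at the 5% significance level (p = 0.0114); field_check_pressure and flight_num are far from significant.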
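Rather than reading the p-values off the printed summary, they can be pulled out of the coefficient table programmatically. The sketch below rebuilds the launch data as a plain data frame so that it runs on its own; launch_df, coef_table and significant are illustrative names, and the 0.05 threshold is simply the 5% level used in the text.

```r
# Launch data re-entered as a base data frame (same values as the launch table above)
launch_df <- data.frame(
  distress = c(0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
               0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 1),
  temperature = c(66, 70, 69, 68, 67, 72, 73, 70, 57, 63, 70, 78,
                  67, 53, 67, 75, 70, 81, 76, 79, 75, 76, 58),
  field_check_pressure = rep(c(50, 100, 200), times = c(6, 2, 15)),
  flight_num = 1:23
)

fit <- lm(distress ~ temperature + field_check_pressure + flight_num,
          data = launch_df)

# coef(summary(fit)) is a numeric matrix: estimate, std. error, t value, p-value
coef_table <- coef(summary(fit))
p_values <- coef_table[, "Pr(>|t|)"]
print(round(p_values, 4))

# Terms whose p-value falls below the 5% level
significant <- names(p_values)[p_values < 0.05]
print(significant)

# 95% confidence intervals for the coefficients
print(confint(fit, level = 0.95))
```

The confidence intervals give a complementary view: an interval that excludes zero corresponds to a coefficient that is significant at the matching level.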
Regression with mlr3
The mlr3 ecosystem provides regr.lm, a learner that calls lm(), so let us also run the regression with it.
learner_lm <- lrn("regr.lm")
task <- launch %>%
select(-Date) %>%
as_task_regr(target = "distress")
task$set_col_roles("Flight", roles = "name")
print(task)
## <TaskRegr:.> (23 x 4)
## * Target: distress
## * Properties: -
## * Features (3):
## - dbl (2): field_check_pressure, temperature
## - int (1): flight_num
Confirm that the task actually contains the data.
print(task$data())
## distress field_check_pressure flight_num temperature
## 1: 0 50 1 66
## 2: 1 50 2 70
## 3: 0 50 3 69
## 4: 0 50 4 68
## 5: 0 50 5 67
## 6: 0 50 6 72
## 7: 0 100 7 73
## 8: 0 100 8 70
## 9: 1 200 9 57
## 10: 1 200 10 63
## 11: 1 200 11 70
## 12: 0 200 12 78
## 13: 0 200 13 67
## 14: 2 200 14 53
## 15: 0 200 15 67
## 16: 0 200 16 75
## 17: 0 200 17 70
## 18: 0 200 18 81
## 19: 0 200 19 76
## 20: 0 200 20 79
## 21: 2 200 21 75
## 22: 0 200 22 76
## 23: 1 200 23 58
## distress field_check_pressure flight_num temperature
Once the data is a task, a pairs plot is also easy to draw.
autoplot(task, type = "pairs")
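For resampling, it is enough to use all the data for training: a holdout with ratio = 1 puts every row in the training set and leaves the test set empty.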
resampling <- rsmp("holdout", ratio = 1)
rr <- resample(task, learner_lm, resampling, store_models = TRUE)
## INFO [10:03:40.607] [mlr3] Applying learner 'regr.lm' on task '.' (iter 1/1)
As expected, the coefficients agree with the lm() results above.
print(rr$learners[[1]]$model)
##
## Call:
## stats::lm(formula = task$formula(), data = task$data())
##
## Coefficients:
## (Intercept) field_check_pressure flight_num
## 3.527093 0.001757 0.014293
## temperature
## -0.051386
6.2 Example --- Predicting Medical Expenses Using Linear Regression
Omitted because the data is not available.
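6.3 What Are Regression Trees and Model Trees?
Compute the standard deviation reduction (SDR) for an example. For a split of T into subsets T_i, SDR = sd(T) - sum_i (|T_i| / |T|) * sd(T_i).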
tee <- c(1, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 7)
at1 <- tee[1:9]
at2 <- tee[10:15]
bt1 <- tee[1:7]
bt2 <- tee[8:15]
sdr_a <-
sd(tee) -
(length(at1) / length(tee) * sd(at1) + length(at2) / length(tee) * sd(at2))
sdr_b <-
sd(tee) -
(length(bt1) / length(tee) * sd(bt1) + length(bt2) / length(tee) * sd(bt2))
print(paste0("sdr_a: ", sdr_a))
## [1] "sdr_a: 1.20281456809792"
print(paste0("sdr_b: ", sdr_b))
## [1] "sdr_b: 1.39275139353039"
The split into bt1 and bt2 gives the larger reduction in standard deviation (about 1.39 versus 1.20), so it is the better split.
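The two SDR calculations above repeat the same formula, so it can be wrapped in a small helper and evaluated at every candidate cut point. This is a minimal sketch: sdr() and its left argument (the indices that go to the left branch) are hypothetical names, and the scan is restricted to cuts that leave at least two values on each side so that sd() is defined.

```r
# Standard deviation reduction for splitting y into y[left] and the remaining values
sdr <- function(y, left) {
  right <- setdiff(seq_along(y), left)
  n <- length(y)
  sd(y) - (length(left) / n * sd(y[left]) + length(right) / n * sd(y[right]))
}

tee <- c(1, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 7)

print(sdr(tee, 1:9))  # same split as at1 / at2
print(sdr(tee, 1:7))  # same split as bt1 / bt2

# Evaluate every cut position that leaves at least two values per side
cuts <- 2:(length(tee) - 2)
scores <- sapply(cuts, function(k) sdr(tee, seq_len(k)))
print(data.frame(cut_after = cuts, sdr = round(scores, 4)))
print(cuts[which.max(scores)])  # position of the cut with the largest SDR
```

A regression tree learner performs this kind of scan over every feature and threshold, then recurses on each branch.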
6.4 Example --- Estimating Wine Quality with Regression Trees and Model Trees
Prepare the data.
url <- "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv"
wine <- read_delim(url,
delim = ";",
col_names = TRUE,
col_types = "dddddddddddi"
) %>%
as_tibble(.name_repair = "universal")
## New names:
## * `fixed acidity` -> fixed.acidity
## * `volatile acidity` -> volatile.acidity
## * `citric acid` -> citric.acid
## * `residual sugar` -> residual.sugar
## * `free sulfur dioxide` -> free.sulfur.dioxide
## * ...
str(wine)
## tibble [4,898 x 12] (S3: tbl_df/tbl/data.frame)
## $ fixed.acidity : num [1:4898] 7 6.3 8.1 7.2 7.2 8.1 6.2 7 6.3 8.1 ...
## $ volatile.acidity : num [1:4898] 0.27 0.3 0.28 0.23 0.23 0.28 0.32 0.27 0.3 0.22 ...
## $ citric.acid : num [1:4898] 0.36 0.34 0.4 0.32 0.32 0.4 0.16 0.36 0.34 0.43 ...
## $ residual.sugar : num [1:4898] 20.7 1.6 6.9 8.5 8.5 6.9 7 20.7 1.6 1.5 ...
## $ chlorides : num [1:4898] 0.045 0.049 0.05 0.058 0.058 0.05 0.045 0.045 0.049 0.044 ...
## $ free.sulfur.dioxide : num [1:4898] 45 14 30 47 47 30 30 45 14 28 ...
## $ total.sulfur.dioxide: num [1:4898] 170 132 97 186 186 97 136 170 132 129 ...
## $ density : num [1:4898] 1.001 0.994 0.995 0.996 0.996 ...
## $ pH : num [1:4898] 3 3.3 3.26 3.19 3.19 3.26 3.18 3 3.3 3.22 ...
## $ sulphates : num [1:4898] 0.45 0.49 0.44 0.4 0.4 0.44 0.47 0.45 0.49 0.45 ...
## $ alcohol : num [1:4898] 8.8 9.5 10.1 9.9 9.9 10.1 9.6 8.8 9.5 11 ...
## $ quality : int [1:4898] 6 6 6 6 6 6 6 6 6 6 ...
Check the distribution of quality with a histogram.
g <- ggplot(wine, aes(quality)) +
geom_histogram(binwidth = 1, colour = "black", fill = "white") +
xlab("Quality") + ylab("Count")
print(g)
Start by training a regression tree model, using the regr.rpart learner.
set.seed(0)
learner <- lrn("regr.rpart")
task <- as_task_regr(wine, target = "quality")
resampling <- rsmp("holdout", ratio = 0.75)
rr <- resample(task, learner, resampling, store_models = TRUE)
## INFO [10:03:42.066] [mlr3] Applying learner 'regr.rpart' on task 'wine' (iter 1/1)
Extract the trained model.
m.rpart <- rr$learners[[1]]$model
print(m.rpart)
## n= 3674
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 3674 2923.10100 5.878879
## 2) alcohol< 10.85 2298 1408.37900 5.606614
## 4) volatile.acidity>=0.2525 1223 613.98040 5.360589
## 8) free.sulfur.dioxide< 17.5 179 92.94972 4.983240 *
## 9) free.sulfur.dioxide>=17.5 1044 491.17240 5.425287 *
## 5) volatile.acidity< 0.2525 1075 636.15440 5.886512 *
## 3) alcohol>=10.85 1376 1059.88900 6.333576
## 6) free.sulfur.dioxide< 11.5 86 106.33720 5.383721 *
## 7) free.sulfur.dioxide>=11.5 1290 870.78760 6.396899
## 14) alcohol< 11.74167 641 430.43990 6.199688 *
## 15) alcohol>=11.74167 649 390.79510 6.591680 *
summary(m.rpart) gives more detail, but the output is long and is omitted here.
To help understand the decision tree, visualize it with the rpart.plot package [1].
rpart.plot(m.rpart, digits = 3)
Try different plotting parameters.
rpart.plot(m.rpart, digits = 4, fallen.leaves = TRUE, type = 3, extra = 101)
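Extract the predictions on the held-out test set into p.rpart and check the summary statistics of the predicted and actual values.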
p.rpart <- rr$predictions()[[1]]
print(summary(p.rpart$response))
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 4.983 5.425 5.887 5.887 6.200 6.592
print(summary(p.rpart$truth))
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 3.000 5.000 6.000 5.875 6.000 9.000
Check the correlation between the predicted and actual values, using the predictions stored in p.rpart.
print(cor(p.rpart$response, p.rpart$truth))
## [1] 0.524354
Measure performance with the mean absolute error (MAE).
p.rpart$score(msr("regr.mae"))
## regr.mae
## 0.5837126
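The MAE is simply the average of the absolute differences between actual and predicted values. The base-R sketch below spells that out on a tiny made-up example. The vectors truth and response are hypothetical illustration values, not wine data, and mae() is an illustrative helper rather than an mlr3 function.

```r
# Mean absolute error: average size of the prediction errors, ignoring their sign
mae <- function(truth, response) {
  mean(abs(truth - response))
}

# Hypothetical actual and predicted quality scores, for illustration only
truth <- c(5, 6, 6, 7)
response <- c(5.4, 5.9, 6.2, 6.6)
print(mae(truth, response))

# Naive baseline: always predict the mean of the actual values
baseline <- rep(mean(truth), length(truth))
print(mae(truth, baseline))
```

With the objects from the text, mae(p.rpart$truth, p.rpart$response) should reproduce the regr.mae score, and comparing it with the mean-prediction baseline shows how much the tree actually adds.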
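Next, try a model tree. mlr3 does not appear to provide a learner for the M5' algorithm, so we use regr.cubist, a learner based on the Cubist algorithm, which builds on the M5 family of model trees.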
learner <- lrn("regr.cubist")
rr <- resample(task, learner, resampling, store_models = TRUE)
## INFO [10:03:42.843] [mlr3] Applying learner 'regr.cubist' on task 'wine' (iter 1/1)
Extract the trained model as m.cubist.
m.cubist <- rr$learners[[1]]$model
print(m.cubist)
##
## Call:
## cubist.default(x = x, y = y, committees =
## self$param_set$values$committees, control = control, weights = if
## ("weights" %in% task$properties) task$weights$weight else NULL)
##
## Number of samples: 3674
## Number of predictors: 11
##
## Number of committees: 1
## Number of rules: 9
print(summary(m.cubist))
##
## Call:
## cubist.default(x = x, y = y, committees =
## self$param_set$values$committees, control = control, weights = if
## ("weights" %in% task$properties) task$weights$weight else NULL)
##
##
## Cubist [Release 2.07 GPL Edition] Thu Sep 02 10:03:42 2021
## ---------------------------------
##
## Target attribute `outcome'
##
## Read 3674 cases (12 attributes) from undefined.data
##
## Model:
##
## Rule 1: [164 cases, mean 5.2, range 3 to 7, est err 0.4]
##
## if
## alcohol <= 10
## citric.acid <= 0.26
## free.sulfur.dioxide > 43.5
## volatile.acidity > 0.205
## then
## outcome = 67.2 + 0.308 alcohol - 62 density - 0.93 pH
## + 0.028 residual.sugar - 1.04 sulphates
## - 0.41 volatile.acidity + 0.26 citric.acid
##
## Rule 2: [258 cases, mean 5.2, range 3 to 7, est err 0.5]
##
## if
## alcohol <= 10
## citric.acid <= 0.26
## free.sulfur.dioxide <= 43.5
## volatile.acidity > 0.205
## then
## outcome = 244.8 - 243 density + 0.091 residual.sugar
## + 0.019 free.sulfur.dioxide + 1.68 citric.acid
## - 0.34 volatile.acidity + 0.18 pH + 0.02 fixed.acidity
## + 0.013 alcohol + 0.09 sulphates
##
## Rule 3: [45 cases, mean 5.3, range 5 to 6, est err 0.3]
##
## if
## citric.acid > 0.26
## residual.sugar > 17.9
## volatile.acidity > 0.205
## then
## outcome = 4.8 - 17 chlorides + 0.058 residual.sugar
##
## Rule 4: [171 cases, mean 5.3, range 3 to 6, est err 0.5]
##
## if
## alcohol <= 10
## citric.acid > 0.26
## density <= 0.99681
## total.sulfur.dioxide > 173
## volatile.acidity > 0.205
## then
## outcome = 216.7 - 216 density + 0.108 residual.sugar
## + 0.005 total.sulfur.dioxide + 0.21 fixed.acidity
## - 1.74 volatile.acidity + 0.0067 free.sulfur.dioxide
## + 0.98 sulphates + 2.7 chlorides
##
## Rule 5: [230 cases, mean 5.5, range 4 to 8, est err 0.5]
##
## if
## alcohol <= 10
## citric.acid > 0.26
## density <= 0.99681
## total.sulfur.dioxide <= 173
## volatile.acidity > 0.205
## then
## outcome = 103.3 - 99 density - 2.64 volatile.acidity + 0.163 alcohol
## + 0.0022 total.sulfur.dioxide - 0.06 fixed.acidity
## + 0.006 residual.sugar + 0.0015 free.sulfur.dioxide
## - 0.12 citric.acid
##
## Rule 6: [395 cases, mean 5.6, range 3 to 9, est err 0.5]
##
## if
## citric.acid > 0.26
## density > 0.99681
## residual.sugar <= 17.9
## volatile.acidity > 0.205
## then
## outcome = 6.8 - 1.72 citric.acid - 1.74 volatile.acidity - 4.1 chlorides
## + 0.63 sulphates
##
## Rule 7: [536 cases, mean 5.8, range 3 to 8, est err 0.6]
##
## if
## alcohol > 10
## free.sulfur.dioxide <= 21
## then
## outcome = 278.3 - 279 density + 0.046 free.sulfur.dioxide
## + 0.126 residual.sugar - 2.09 volatile.acidity + 1.1 pH
## + 0.03 fixed.acidity + 0.11 sulphates
##
## Rule 8: [851 cases, mean 6.2, range 4 to 8, est err 0.5]
##
## if
## volatile.acidity <= 0.205
## then
## outcome = 54.5 - 50 density + 0.023 residual.sugar
## - 0.42 volatile.acidity + 0.2 pH + 0.024 alcohol
## + 0.03 fixed.acidity + 0.17 sulphates
## + 0.001 free.sulfur.dioxide
##
## Rule 9: [1614 cases, mean 6.3, range 3 to 9, est err 0.6]
##
## if
## alcohol > 10
## free.sulfur.dioxide > 21
## then
## outcome = 259.2 - 262 density + 0.105 residual.sugar + 1.28 pH
## + 0.2 fixed.acidity + 0.98 sulphates + 0.061 alcohol
## - 0.39 volatile.acidity + 0.0013 free.sulfur.dioxide
##
##
## Evaluation on training data (3674 cases):
##
## Average |error| 0.5
## Relative |error| 0.69
## Correlation coefficient 0.63
##
##
## Attribute usage:
## Conds Model
##
## 70% 73% alcohol
## 60% 86% free.sulfur.dioxide
## 50% 99% volatile.acidity
## 30% 25% citric.acid
## 19% 90% density
## 10% 91% residual.sugar
## 9% 9% total.sulfur.dioxide
## 94% sulphates
## 86% fixed.acidity
## 80% pH
## 14% chlorides
##
##
## Time: 0.2 secs
Compare the predicted and actual values.
p.cubist <- rr$predictions()[[1]]
print(summary(p.cubist$response))
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 3.500 5.484 5.951 5.849 6.177 7.356
print(summary(p.cubist$truth))
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 3.00 5.00 6.00 5.85 6.00 9.00
print(cor(p.cubist$response, p.cubist$truth))
## [1] 0.5595234
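Finally, measure performance with the MAE. The value is lower than for regr.rpart (0.560 versus 0.584), which suggests an improvement. Note, however, that the truth summaries of the two runs differ slightly, so the two models were apparently evaluated on different random holdout splits and the comparison is only indicative.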
p.cubist$score(msr("regr.mae"))
## regr.mae
## 0.5599107
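For a like-for-like comparison, the holdout split can be instantiated once and reused, so that every learner is scored on identical test rows. The sketch below illustrates the idea under some assumptions: it uses the built-in mtcars data and the core learners regr.rpart and regr.featureless so that it runs without the wine download or the Cubist package, and the names task_demo, split_demo, rr_tree and rr_base are illustrative.

```r
library(mlr3)

set.seed(0)

# Illustrative task on built-in data; in the text this would be the wine task
task_demo <- as_task_regr(mtcars, target = "mpg")

# Instantiate the holdout split once so that every learner sees the same rows
split_demo <- rsmp("holdout", ratio = 0.75)
split_demo$instantiate(task_demo)

rr_tree <- resample(task_demo, lrn("regr.rpart"), split_demo)
rr_base <- resample(task_demo, lrn("regr.featureless"), split_demo)

# Both results were scored on identical test rows
print(identical(rr_tree$resampling$test_set(1), rr_base$resampling$test_set(1)))

# MAE of each learner on the shared test set
print(rr_tree$aggregate(msr("regr.mae")))
print(rr_base$aggregate(msr("regr.mae")))
```

Applied to the wine task, the same pattern would mean calling resampling$instantiate(task) once before the two resample() calls for regr.rpart and regr.cubist.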
References
Dalal, S. R., E. B. Fowlkes, and B. Hoadley (1989) "Risk Analysis of the Space Shuttle: Pre-Challenger Prediction of Failure," Journal of the American Statistical Association, Vol. 84, No. 408, pp. 945-957. https://www.researchgate.net/publication/238670560_Risk_Analysis_of_the_Space_Shuttle_Pre-Challenger_Prediction_of_Failure
[1] I would really prefer to use the ggparty package, but….