はじめに
こちらの記事で基礎分析を行った、UCI 機械学習リポジトリのReal estate valuationデータについて回帰木による分析を行った。線形重回帰の結果はこちら。
データの読み込みとデータの分割
データを読み込み列名を変更、さらに train dataset と test dataset を分ける。
library(openxlsx)
df <- read.xlsx("./Real_estate_valuation_data_set.xlsx")
coln <- colnames(df) #列名を保存しておく
colnames(df) <- c("No","X1","X2","X3","X4","X5","X6","Y") #新しい名前を付けておく
library(rsample)
set.seed(1234) #値の再現を確認するときは乱数シードを設定
df_split <- initial_split(df, prop = 0.8) #train:test = 8:2 の割合で分ける
df_train <- training(df_split)
df_test <- testing(df_split)
回帰木分析
回帰木による分析をrpart
関数によって行う。複雑さのパラメータ cp と、分岐される枝の中に含まれる個体の数 minsplit はデフォルトの値を用いる。
library(rpart)
df_train.rp <- rpart(Y~X1+X2+X3+X4+X5+X6, df_train, minsplit = 20, cp = 0.01)
print(df_train.rp, digits = 3)
n= 332
node), split, n, deviance, yval
* denotes terminal node
1) root 332 62100 37.8
2) X3>=982 115 5510 24.4
4) X5< 25 32 426 17.2 *
5) X5>=25 83 2750 27.2
10) X5< 25 72 1170 25.9 *
11) X5>=25 11 606 36.0 *
3) X3< 982 217 25300 44.9
6) X2>=11.1 147 9660 41.0
12) X5< 25 11 764 29.7 *
13) X5>=25 136 7360 41.9
26) X3>=333 83 3280 39.8 *
27) X3< 333 53 3110 45.3
54) X2< 13.8 8 321 36.7 *
55) X2>=13.8 45 2100 46.8 *
7) X2< 11.1 70 8940 52.9
14) X5< 25 33 1510 47.9 *
15) X5>=25 37 5880 57.3
30) X3>=311 19 558 52.3 *
31) X3< 311 18 4310 62.7 *
木が形成され各ノードで判定が行われているのが分かる。このままだと分かりにくいので次節で可視化を行う。その前に、評価用としてRMSEを計算する関数を定義する。
RMSE = function(m, o) {
tmp <- sqrt(mean((m-o)^2))
return(tmp)
}
以上を用いてテストデータの RMSE を評価する。
df_test.y <- predict(df_train.rp, newdata = df_test)
df_test.rmse <- RMSE(df_test.y, df_test$Y)
df_test.rmse
[1] 8.882217
微量ではあるが線形重回帰分析結果よりも良い値が得られた。
回帰木の可視化
回帰木の可視化の方法をいくつか実装した。まずrpart.plot
パッケージを使う方法がある。これはシンプルに実装ができる。
library(rpart.plot)
rpart.plot(df_train.rp, box.palette="RdBu", shadow.col="gray", nn=TRUE)
分岐の様子が分かるが、ノード2とノード5の条件が被って見えたり結果が分かりにくい。そこでもう少し見やすいものが無いか探したところ、ggparty
パッケージを使用する方法があった(参考記事:R ggpartyパッケージを用いた決定木の可視化)。これを参考にこの結果を可視化してみる。
まず、as.party
関数でノードのデータを作る。
library(ggplot2)
library(ggparty)
prt <- as.party(df_train.rp)
prt
Model formula:
Y ~ X1 + X2 + X3 + X4 + X5 + X6
Fitted party:
[1] root
| [2] X3 >= 981.5777
| | [3] X5 < 24.95107: 17.181 (n = 32, err = 426.0)
| | [4] X5 >= 24.95107
| | | [5] X5 < 24.98363: 25.885 (n = 72, err = 1166.7)
| | | [6] X5 >= 24.98363: 36.027 (n = 11, err = 606.4)
| [7] X3 < 981.5777
| | [8] X2 >= 11.1
| | | [9] X5 < 24.96405: 29.691 (n = 11, err = 763.8)
| | | [10] X5 >= 24.96405
| | | | [11] X3 >= 332.80635: 39.811 (n = 83, err = 3276.4)
| | | | [12] X3 < 332.80635
| | | | | [13] X2 < 13.8: 36.737 (n = 8, err = 320.8)
| | | | | [14] X2 >= 13.8: 46.820 (n = 45, err = 2102.8)
| | [15] X2 < 11.1
| | | [16] X5 < 24.97425: 47.900 (n = 33, err = 1505.6)
| | | [17] X5 >= 24.97425
| | | | [18] X3 >= 311.48625: 52.253 (n = 19, err = 557.8)
| | | | [19] X3 < 311.48625: 62.717 (n = 18, err = 4308.0)
Number of inner nodes: 9
Number of terminal nodes: 10
このデータを使い、ggplot2
を使って可視化する。
g <- ggparty(prt, terminal_space = 0.5)
# 分岐の枝を作り、分岐の条件を表示
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 3)
# 結果の所に箱ひげ図をプロット
g <- g + geom_node_plot(
gglist = list(geom_boxplot(aes(x = "", y = Y)), theme_bw(base_size = 12)),
scales = "fixed",
id = "terminal",
shared_axis_labels = TRUE,
shared_legend = TRUE,
legend_separator = TRUE,
)
# 各ノードのラベルを作成
g <- g + geom_node_label(
aes(col = splitvar),
line_list = list(aes(label = paste("Node", id)),
aes(label = splitvar)),
line_gpar = list(list(
size = 10,
col = "black",
fontface = "bold"
),
list(size = 12)),
ids = "inner"
)
# 結果のノードのラベルを作成
g <- g + geom_node_label(
aes(label = paste0("Node ", id, ", N = ", nodesize)),
fontface = "bold",
ids = "terminal",
size = 3,
nudge_y = 0.01
)
# 凡例を消す
g <- g + theme(legend.position = "none")
plot(g)
こちらの図の方が条件も含め見やすくなった。ノードの条件としてX2(house age), X3(distance to the nearest MRT station), X5(latitude)の三つしか使われてないのが分かる。X1(transaction date)とX6(longitude)が影響しないのは重回帰分析の結果からもなんとなく予想がつくが、X4(number of convenience stores)が影響しないのは少し意外な結果だった。
おわりに
Real estate valuation の住宅価格を予測する回帰木の分析を行ってみた。
df_train.rp <- rpart(Y~X1+X2+X3+X4+X5+X6, df_train, minsplit = 10, cp = 0.001)
などとして無理やり枝の数を増やして分析してみたりしたが、それだと過学習がおきてテストデータのRMSEの値は悪くなった。どの値が最適なのか、パラメータサーチの方法は別に考えなければいけないと思う。