機械学習に利用するデータセットを、教師用データと検証用データに分割する。
iris
データの80%を教師データ、20%を検証用として分割し、Random forestによる機械学習を行ってみる。
irisデータセットに連番をふる。
library("tidyverse")
df <- iris %>%
dplyr::as_tibble() %>%
dplyr::mutate(id=dplyr::row_number())
df
# A tibble: 150 x 6
Sepal.Length Sepal.Width Petal.Length Petal.Width Species id
<dbl> <dbl> <dbl> <dbl> <fct> <int>
1 5.1 3.5 1.4 0.2 setosa 1
2 4.9 3 1.4 0.2 setosa 2
3 4.7 3.2 1.3 0.2 setosa 3
4 4.6 3.1 1.5 0.2 setosa 4
5 5 3.6 1.4 0.2 setosa 5
6 5.4 3.9 1.7 0.4 setosa 6
7 4.6 3.4 1.4 0.3 setosa 7
8 5 3.4 1.5 0.2 setosa 8
9 4.4 2.9 1.4 0.2 setosa 9
10 4.9 3.1 1.5 0.1 setosa 10
# … with 140 more rows
各Speciesから80%を教師データとして抽出する。
df.train
を作成する際、dplyr::group_by(Species)
でグループ化することで dplyr::sample_frac(0.8)
が各グループで実行され、層化抽出となる。
df.train <- df %>%
dplyr::group_by(Species) %>%
dplyr::sample_frac(0.8) %>%
dplyr::ungroup()
df.train %>%
dplyr::arrange(id)
# A tibble: 120 x 6
Sepal.Length Sepal.Width Petal.Length Petal.Width Species id
<dbl> <dbl> <dbl> <dbl> <fct> <int>
1 5.1 3.5 1.4 0.2 setosa 1
2 4.9 3 1.4 0.2 setosa 2
3 4.7 3.2 1.3 0.2 setosa 3
4 5.4 3.9 1.7 0.4 setosa 6
5 4.6 3.4 1.4 0.3 setosa 7
6 5 3.4 1.5 0.2 setosa 8
7 4.9 3.1 1.5 0.1 setosa 10
8 5.4 3.7 1.5 0.2 setosa 11
9 4.8 3 1.4 0.1 setosa 13
10 4.3 3 1.1 0.1 setosa 14
# … with 110 more rows
教師データとして抽出しなかった行を抽出してdf.valid
とする。
df.valid <- df %>%
dplyr::anti_join(df.train, by="id")
df.valid
# A tibble: 30 x 6
Sepal.Length Sepal.Width Petal.Length Petal.Width Species id
<dbl> <dbl> <dbl> <dbl> <fct> <int>
1 4.6 3.1 1.5 0.2 setosa 4
2 5 3.6 1.4 0.2 setosa 5
3 4.4 2.9 1.4 0.2 setosa 9
4 4.8 3.4 1.6 0.2 setosa 12
5 4.6 3.6 1 0.2 setosa 23
6 5.2 3.5 1.5 0.2 setosa 28
7 4.8 3.1 1.6 0.2 setosa 31
8 5.5 4.2 1.4 0.2 setosa 34
9 4.9 3.6 1.4 0.1 setosa 38
10 5 3.5 1.6 0.6 setosa 44
# … with 20 more rows
Random Forestを実行し、モデルを作成する。
rf.model <- randomForest::randomForest(Species ~ ., df.train)
rf.model
Call:
randomForest(formula = Species ~ ., data = df.train)
Type of random forest: classification
Number of trees: 500
No. of variables tried at each split: 2
OOB estimate of error rate: 0.83%
Confusion matrix:
setosa versicolor virginica class.error
setosa 40 0 0 0.000
versicolor 0 40 0 0.000
virginica 0 1 39 0.025
Out of Bagエラーが0.83%の精度のモデルができた。このモデルをvalidation用データセットに適用して精度を検証する。
predict <- stats::predict(rf.model, df.valid)
(result <- table(predict, df.valid$Species))
predict setosa versicolor virginica
setosa 10 0 0
versicolor 0 10 0
virginica 0 0 10
(accuracy_prediction = sum(diag(result)) / sum(result))
[1] 1
精度100%となりました。