LoginSignup
1
1

More than 3 years have passed since last update.

tidyverseで機械学習用にデータセットを分割する

Posted at

機械学習に利用するデータセットを、教師用データと検証用データに分割する。
irisデータの80%を教師データ、20%を検証用として分割し、Random forestによる機械学習を行ってみる。

irisデータセットに連番をふる。

library("tidyverse")
df <- iris %>%
  dplyr::as_tibble() %>% 
  dplyr::mutate(id=dplyr::row_number())

df
# A tibble: 150 x 6
   Sepal.Length Sepal.Width Petal.Length Petal.Width Species    id
          <dbl>       <dbl>        <dbl>       <dbl> <fct>   <int>
 1          5.1         3.5          1.4         0.2 setosa      1
 2          4.9         3            1.4         0.2 setosa      2
 3          4.7         3.2          1.3         0.2 setosa      3
 4          4.6         3.1          1.5         0.2 setosa      4
 5          5           3.6          1.4         0.2 setosa      5
 6          5.4         3.9          1.7         0.4 setosa      6
 7          4.6         3.4          1.4         0.3 setosa      7
 8          5           3.4          1.5         0.2 setosa      8
 9          4.4         2.9          1.4         0.2 setosa      9
10          4.9         3.1          1.5         0.1 setosa     10
# … with 140 more rows

各Speciesから80%を教師データとして抽出する。
df.trainを作成する際、dplyr::group_by(Species)でグループ化することでdplyr::sample_frac(0.8)が各グループで実行され、層化抽出となる。

df.train <- df %>% 
  dplyr::group_by(Species) %>%
  dplyr::sample_frac(0.8) %>% 
  dplyr::ungroup()

df.train %>% 
  dplyr::arrange(id)
# A tibble: 120 x 6
   Sepal.Length Sepal.Width Petal.Length Petal.Width Species    id
          <dbl>       <dbl>        <dbl>       <dbl> <fct>   <int>
 1          5.1         3.5          1.4         0.2 setosa      1
 2          4.9         3            1.4         0.2 setosa      2
 3          4.7         3.2          1.3         0.2 setosa      3
 4          5.4         3.9          1.7         0.4 setosa      6
 5          4.6         3.4          1.4         0.3 setosa      7
 6          5           3.4          1.5         0.2 setosa      8
 7          4.9         3.1          1.5         0.1 setosa     10
 8          5.4         3.7          1.5         0.2 setosa     11
 9          4.8         3            1.4         0.1 setosa     13
10          4.3         3            1.1         0.1 setosa     14
# … with 110 more rows

教師データとして抽出しなかった行を抽出してdf.validとする。

df.valid <- df %>% 
  dplyr::anti_join(df.train, by="id")

df.valid
# A tibble: 30 x 6
   Sepal.Length Sepal.Width Petal.Length Petal.Width Species    id
          <dbl>       <dbl>        <dbl>       <dbl> <fct>   <int>
 1          4.6         3.1          1.5         0.2 setosa      4
 2          5           3.6          1.4         0.2 setosa      5
 3          4.4         2.9          1.4         0.2 setosa      9
 4          4.8         3.4          1.6         0.2 setosa     12
 5          4.6         3.6          1           0.2 setosa     23
 6          5.2         3.5          1.5         0.2 setosa     28
 7          4.8         3.1          1.6         0.2 setosa     31
 8          5.5         4.2          1.4         0.2 setosa     34
 9          4.9         3.6          1.4         0.1 setosa     38
10          5           3.5          1.6         0.6 setosa     44
# … with 20 more rows

Random Forestを実行し、モデルを作成する。

rf.model <- randomForest::randomForest(Species ~ ., df.train)
rf.model
Call:
 randomForest(formula = Species ~ ., data = df.train) 
               Type of random forest: classification
                     Number of trees: 500
No. of variables tried at each split: 2

        OOB estimate of  error rate: 0.83%
Confusion matrix:
           setosa versicolor virginica class.error
setosa         40          0         0       0.000
versicolor      0         40         0       0.000
virginica       0          1        39       0.025

Out of Bagエラーが0.83%の精度のモデルができた。このモデルをvalidation用データセットに適用して精度を検証する。

predict <- stats::predict(rf.model, df.valid)
(result <- table(predict, df.valid$Species))
predict      setosa versicolor virginica
  setosa         10          0         0
  versicolor      0         10         0
  virginica       0          0        10
(accuracy_prediction = sum(diag(result)) / sum(result))
[1] 1

精度100%となりました。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1