LoginSignup
19

More than 5 years have passed since last update.

caret vignette "A Short Introduction to the caret Package"の和訳

Last updated at Posted at 2016-03-12

巻頭言

CRAN caret のvignette https://cran.r-project.org/web/packages/caret/vignettes/caret.pdf の和訳です。

2018/03/18 vignette Octover 28, 2016版に対応しました。直訳で分かりにくいところ更に訳し直しました。


caretパッケージには、複雑な回帰と分類の学習を簡易に実施できる関数群があります。多数のRパッケージから構成されていますが、最初に全てを読み込む必要はありません1。初期状態では27のパッケージが利用できます。それ以外のパッケージについては、caretでは必要となった時にそれを読み込みます。

caretのインストールは

install.packages("caret", dependencies = c("Depends", "Suggests"))

となっており、必要な物をインストールします。
ヘルプページは:
http://caret.r-forge.r-project.org/
です。

このvignetteには多くの例と解説があります。

caretにはモデルの構築と評価について簡易に実施出来る関数が複数あります。また、特徴量の選択や他の処理に関する関数もあります。

このパッケージの機能で主なものは、

  • リサンプリングによる、パラメータチューニングの効果の評価
  • それらパラメータによる「最適」モデルの選択
  • トレーニングセットによるモデルパフォーマンスの評価 です。

caretパッケージの書式は:

1.  評価するためのモデルのパラメータの設定
2.  for each パラメータ do
3.      for each リサンプリング do 
4.          特定のサンプルの取り出し
5.          [オプション] データの前処理
6.          モデルの調整
7.          取り出したサンプルでの予測
8.      end
9.      得られた複数の予測についての平均性能の計算
10. end
11. 最適パラメータの決定
12. 最適なパラメータセットを用いてモデルをトレーニングデータに合致させる

このプロセスの各々のステップにおいてカスタマイズが出来ます(例えば、リサンプリングの方法、最適パラメータの選択、など)。この関数をデモするためにmlbenchパッケージのSonarデータを用います。

Sonarデータには60の項目について208のデータがあります。ゴールは2つのクラス(金属シリンダーを示すMと岩を示すR)を判別することです。

最初に、データを2つのグループに分けます:トレーニングセットとテストセットです。これにはcreateDataPartition関数を用います:

library(caret)
library(mlbench)
data(Sonar)
set.seed(107)
inTrain <- createDataPartition(y = Sonar$Class,
+ ## the outcome data are needed
+ p = .75,
+ ## The percentage of data in the
+ ## training set
+ list = FALSE)
## The format of the results

## The output is a set of integers for the rows of Sonar
## that belong in the training set.
str(inTrain)

int [1:157, 1] 1 2 3 6 7 9 10 11 12 13 ...
- attr(*, "dimnames")=List of 2
 ..$ : NULL
 ..$ : chr "Resample1"

デフォルトでは、createDataPartition関数はデータをランダムに分けます。データを分けるために下記を実行します:

training <- Sonar[ inTrain,]
testing <- Sonar[-inTrain,]
nrow(training)
[1] 157
nrow(testing)
[1] 51

アルゴリズム??にてモデルをチューニングするために、train関数を用います。この関数の詳細は

http://caret.r-forge.r-project.org/training.html
ここでは、PLS判別分析(a partial least squares discriminant analysis (PLSDA))モデルを用いてチューニングします。基本的な書式は次の通りです:

plsFit <- train(Class ~ .,
+ data = training,
+ method = "pls",
+ ## Center and scale the predictors for the training
+ ## set and all future samples.
+ preProc = c("center", "scale"))

しかし、いくつかカスタマイズすべき点があります:

  • PLSモデルを拡張して評価すること。デフォルトでは、3つ以上のパラメータをチューニングします。
  • リサンプリングの方法。デフォルトでは簡単なブートストラップを用いています。ここでは10-foldクロスバリデーションを3回繰り返します。
  • 性能評価の方法。指定がなければ、Κ(Kappa)統計量を計算します。回帰モデルでは、RMSEとR^2を計算します。ここでは、ROC曲線の下面積を評価し、感度と特異度を見ます。

チューニングパラメータの値を変更するために、tuneLengthとtuneGridを調整出来ます。train関数はパラメータ値の候補を設定しますが、tuneLengthによりどれくらいの範囲で評価するかを決めます。PDSの場合には、1からtuneLengthの値までの整数値を用います。もし1から15の整数値を評価したい場合には、tuneLength = 15とします。tuneGridは、特定の値を設定したいときに用います。各行にチューニングの属性、各列にチューニングの値としたデータフレームを指定します。
書式は次のようになります:

plsFit <- train(Class ~ .,
+ data = training,
+ method = "pls",
+ tuneLength = 15,
+ preProc = c("center", "scale"))

リサンプリングを決めるためにtrainControl関数を用います。オプションにより、リサンプリングの形式を指定しますがデフォルトは"boot"です。別の方法"repeatedcv"では、k回のクロスバリデーション(そして、繰り返し回数を指定します)を行います。Kのデフォルトは10です。書式は次の通りです:

ctrl <- trainControl(method = "repeatedcv",
+ repeats = 3)
plsFit <- train(Class ~ .,
+ data = training,
+ method = "pls",
+ tuneLength = 15,
+ trControl = ctrl,
+ preProc = c("center", "scale"))

最後に、性能計測の方法を指定するtrainControlを付け加えます。
summaryFunction属性は 性能の評価に関する値を観測かつ推定するためのものです。2つの関数がパッケージには含まれています:defaultSummary と twoClassSummaryです。後者は、2値判別問題で使われます。すなわちROC曲線の下面積による感度と特異性についてのものです。ROC曲線はクラス予測確率(自動には計算されない値)に基いて計算されるのでオプションが必要です。classProbs = TRUE オプションを指定します。
最後に、最高の結果になるようパラメータをチューニングします。ここではカスタムな性能計測指標を使ったので、最適化のための閾値を指定しなければなりません。
trainするときにmetric = "ROC"と指定する必要があります。
最終的にモデルは次のようになります:

set.seed(123)
ctrl <- trainControl(method = "repeatedcv",
+ repeats = 3,
+ classProbs = TRUE,
+ summaryFunction = twoClassSummary)
plsFit <- train(Class ~ .,
+ data = training,
+ method = "pls",
+ tuneLength = 15,
+ trControl = ctrl,
+ metric = "ROC",
+ preProc = c("center", "scale"))

plsFit

Partial Least Squares
The caret Package
157 samples
 60 predictor
 2 classes: 'M', 'R'

Pre-processing: centered, scaled
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 142, 141, 141, 141, 142, 142, ...
Resampling results across tuning parameters:

ncomp ROC Sens Spec ROC SD Sens SD Spec SD
 1 0.8111772 0.7120370 0.7172619 0.1177583 0.1573858 0.1850478
 2 0.8713211 0.7782407 0.8130952 0.1066032 0.1630141 0.1539514
 3 0.8660962 0.7699074 0.8339286 0.1069895 0.1548614 0.1184648
 4 0.8660136 0.7740741 0.7714286 0.1028392 0.1704134 0.1448749
 5 0.8504216 0.7532407 0.7845238 0.1049264 0.1685590 0.1747078
 6 0.8352679 0.7574074 0.8035714 0.1090738 0.1520460 0.1671863
 7 0.8093419 0.7296296 0.7791667 0.1270386 0.1510365 0.1687237
 8 0.8080688 0.7259259 0.7744048 0.1403985 0.1740438 0.1766201
 9 0.8122933 0.7291667 0.7517857 0.1433122 0.1410158 0.1753622
10 0.8182870 0.7296296 0.7654762 0.1378789 0.1474260 0.1781851
11 0.8303241 0.7416667 0.7702381 0.1161916 0.1440423 0.1825742
12 0.8203869 0.7458333 0.7738095 0.1339745 0.1532842 0.1802467
13 0.8223958 0.7337963 0.7744048 0.1325278 0.1566296 0.1756838
14 0.8165840 0.7375000 0.7833333 0.1398943 0.1630494 0.1695202
15 0.8090939 0.7337963 0.7744048 0.1435415 0.1633646 0.1756838

ROC was used to select the optimal model using the largest value.
The final value used for the model was ncomp = 2.

この出力では、リサンプルした平均で性能を評価しています。表の最終行のPLS成分が最適値を示しています。この値に基づいて、PLSモデルがデータセットに適しており予測に用いられることとなります。

このパッケージには結果の可視化のためのいくつかの関数があります。1つの方法はtrainオブジェクトをプロットするものです。
plot(plsFit) コマンドの結果を図1に示します。この図はPLS成分の数とリサンプル評価との間の関係を示しています。

(略)
図1: plot(plsFit) は、PLS成分の数とROC曲線の下面積によるリサンプル評価との関係を示します。

新しいサンプルで予測を行う場合には、 predict.train を用います。分類のモデルにおいては、デフォルトの動作ではクラスの予測を出力します。 type = "prob" オプションを指定することにより、クラスに配分される確率を出力します。例:

plsClasses <- predict(plsFit, newdata = testing)
head(plsProbs)

           M         R
4  0.3762529 0.6237471
5  0.5229047 0.4770953
8  0.5839468 0.4160532
16 0.3660142 0.6339858
20 0.7351013 0.2648987
25 0.2135788 0.7864212

caret には、混同行列(confusion matrix)とモデルフィットに関連する統計値を計算する関数もあります:

confusionMatrix(data = plsClasses, testing$Class)

$positive
[1] "M"

$table
          Reference
Prediction  M  R
         M 20  7
         R  7 17

$overall
      Accuracy          Kappa   AccuracyLower  AccuracyUpper  AccuracyNull
   0.725490196    0.449074074     0.582552477       
  0.841072735   0.529411765
AccuracyPValue  McnemarPValue
   0.003346986    1.000000000

$byClass
          Sensitivity        Specificity   Pos Pred Value
            0.7407407          0.7083333        0.7407407
       Neg Pred Value          Precision           Recall
            0.7083333          0.7407407        0.7407407
                   F1         Prevalence  Detection Rate
            0.7407407          0.5294118       0.3921569
 Detection Prevalence  Balanced Accuracy
            0.5294118          0.7245370

$mode
[1] "sens_spec"

$dots
list()

attr(,"class")
[1] "confusionMatrix"

別のモデルを使用したい場合でも、train関数では最小限の変更で済みます。使えるモデルのリストは以下にあります:
 http://caret.r-forge.r-project.org/modelList.html
 http://caret.r-forge.r-project.org/bytag.html

例えば、このデータにregularized discriminantモデルをフィットさせたい場合には:

## To illustrate, a custom grid is used
rdaGrid = data.frame(gamma = (0:4)/4, lambda = 3/4)
set.seed(123)
rdaFit <- train(Class ~ .,
+ data = training,
+ method = "rda",
+ tuneGrid = rdaGrid,
+ trControl = ctrl,
+ metric = "ROC")
rdaFit

Regularized Discriminant Analysis

157 samples
 60 predictor
  2 classes: 'M', 'R'

No pre-processing
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 142, 141, 141, 141, 142, 142, ...
Resampling results across tuning parameters:

  gamma ROC      Sens      Spec
  0.00 0.8448826 0.7884259 0.7625000
  0.25 0.8860119 0.8060185 0.8035714
  0.50 0.8851190 0.8097222 0.7666667
  0.75 0.8685847 0.7745370 0.7529762
  1.00 0.7563823 0.6615741 0.6803571

Tuning parameter 'lambda' was held constant at a value of 0.75
ROC was used to select the optimal model using the largest value.
The final values used for the model were gamma = 0.25 and lambda = 0.75.

rdaClasses <- predict(rdaFit, newdata = testing)
confusionMatrix(rdaClasses, testing$Class)

$positive
[1] "M"

$table
           Reference
Prediction  M  R
         M 22  5
         R  5 19

$overall
      Accuracy         Kappa  AccuracyLower AccuracyUpper  AccuracyNull
  0.8039215686  0.6064814815   0.6688426487  
 0.9017565835  0.5294117647
AccuracyPValue  McnemarPValue
0.0000434105     1.0000000000

$byClass
         Sensitivity        Specificity    Pos Pred Value
           0.8148148          0.7916667         0.8148148
      Neg Pred Value          Precision            Recall
           0.7916667          0.8148148         0.8148148
                  F1         Prevalence    Detection Rate
           0.8148148          0.5294118         0.4313725
Detection Prevalence  Balanced Accuracy
           0.5294118          0.8032407

$mode
[1] "sens_spec"

$dots
list()

attr(,"class")
[1] "confusionMatrix"

リサンプリング結果からモデルの比較をどのように行うのか?リサンプルの関数は、収集、要約とリサンプル結果の対比に用います。乱数のシーズ(seeds)値は、trainを呼び出す前に同じ値に初期化されるため、それぞれのモデルの間で乱数は同じ値となります。
組み合わせると:

resamps <- resamples(list(pls = plsFit, rda = rdaFit))
summary(resamps)

Call:
summary.resamples(object = resamps)
Models: pls, rda
Number of resamples: 30

ROC
      Min. 1st Qu. Median   Mean 3rd Qu. Max. NA's
pls 0.5397  0.8333 0.8672 0.8713  0.9509    1    0
rda 0.6508  0.8214 0.8750 0.8769  0.9509    1    0

Sens
      Min. 1st Qu. Median   Mean 3rd Qu. Max. NA's
pls 0.3333    0.75 0.7778 0.7782  0.8750    1    0
rda 0.4444    0.75 0.8750 0.8144  0.8889    1    0

Spec
      Min. 1st Qu. Median   Mean 3rd Qu. Max. NA's
pls 0.5000  0.7143 0.8571 0.8131  0.9688    1    0
rda 0.2857  0.7143 0.7143 0.7655  0.8571    1    0

結果の可視化のための関数がいくつかあります。例えば、 Bland-Altman type プロットは xyplot(resamps, what = "BlandAltman") で呼び出せます(図2参照)。結果は似ているように見えます。各リサンプルについて対となる結果があるので、ペアのt検定が使え、ROC曲線下部の面積の平均値の差の評価が可能です。diff.resamples 関数で次のように実施します:

diffs <- diff(resamps)
summary(diffs)

Call:
summary.diff.resamples(object = diffs)

p-value adjustment: bonferroni
Upper diagonal: estimates of the difference
Lower diagonal: p-value for H0: difference = 0

ROC
    pls     rda
pls         -0.01469
rda 0.2975

Sens
    pls    rda
pls        -0.02778
rda 0.125

Spec
    pls     rda
pls         0.009524
rda 0.7348

(略)
図2:xyplot(resamps, what = "BlandAltman")の出力であるリサンプルROC値の Bland–Altman プロット。

この分析に基づいて、モデル間の差は -0.015 ROC units(RDAモデルがやや高い)であり、両側検定のp値は 0.29749 となります。


  1. 必要なものを後から追加することにより、パッケージの起動を大幅に早くしています。 

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19