LoginSignup
1
1

More than 5 years have passed since last update.

caret: General Topics - Model Performanceの和訳

Posted at

訳注

http://topepo.github.io/caret/other.html の和訳です

テストセットの評価

postResample 関数は、回帰あるいは分類の train によるパフォーマンスを計測するために用います。

caret には分類モデルのパフォーマンスを計測するための関数がいくつかあります。感度、特異度を見る関数として、poPredValue と negPredValue があり、2分類のパフォーマンスを計測します。デフォルトでは、アウトカム要素の1番目は"positive" 結果(すなわち、興味のある事象)と定義されますが、変えることも出来ます。

confusionMatrix 関数は分類モデルの結果の要約をします。モデルのトレーニングとチューニングの例を示します:

testPred <- predict(gbmFit3, testing)
postResample(testPred, testing$Class)

 Accuracy     Kappa 
0.8627451 0.7226107 
sensitivity(testPred, testing$Class)

[1] 0.9259259

confusionMatrix(testPred, testing$Class)

Confusion Matrix and Statistics

          Reference
Prediction  M  R
         M 25  5
         R  2 19

               Accuracy : 0.8627         
                 95% CI : (0.7374, 0.943)
    No Information Rate : 0.5294         
    P-Value [Acc > NIR] : 5.008e-07      

                  Kappa : 0.7226         
 Mcnemar's Test P-Value : 0.4497         

            Sensitivity : 0.9259         
            Specificity : 0.7917         
         Pos Pred Value : 0.8333         
         Neg Pred Value : 0.9048         
             Prevalence : 0.5294         
         Detection Rate : 0.4902         
   Detection Prevalence : 0.5882         
      Balanced Accuracy : 0.8588         

       'Positive' Class : M              

"no--information rate" は、観測されたクラス(このテストセットでよりアクティブなもの)の最大の比率となります。最大のクラスの確率よりも大きいかどうか仮説検定を行います。"positie event" の prevalence がデータ(属性によらず)から計算されます。検出率(全事象に対する真陽性の率)と detection prevalence (予測された事象の prevalence)です。

2x2 表で考えると

Reference
Predicted Event No Event
Event A B
No Event C D

式で考えると:

\begin{align}
感度 Sensitivity &= \frac{A}{A+C} \\
特異度 Specificity &= \frac{D}{B+D} \\
Prevalence &= \frac{A+C}{A+B+C+D} \\
PPV &= \frac{sensitivity \times prevalence}{(sensitivity \times prevalence)+(1 - sensitivity)\times(1 - prevalence)} \\
NPV &= \frac{sensitivity \times (1 - prevalence)}{((1 - sensitivity) \times prevalence) + sensitivity \times (1 - prevalence)} \\
検出率 Detection Rate &= \frac{A}{A+B+C+D} \\
Detection Prevalence &= \frac{A+B}{A+B+C+D}
\end{align}

3つ以上のクラスがあるとき、confusionMatrix で混同行列を示し、「1対多」の結果の集合を示します。例えば、3クラス分類問題で、最初のクラスの感度は2番目と3番めのクラスに入るすべてのサンプルに対するものとして計算されます。

トレーニングセットのリサンプルの評価では confusionMatrix.train が使われます。各リサンプルの段階について、hold-out サンプルとそれらの値から混同行列が作成され、モデルのフィッティングに関しての診断材料として用いられます。

例えば:

confusionMatrix(gbmFit3)

Cross-Validated (10 fold, repeated 10 times) Confusion Matrix 

(entries are percentages of table totals)

          Reference
Prediction    M    R
         M 47.2 10.8
         R  6.3 35.7

リサンプルの間に混同行列に hold-out サンプルがあるパーセントをこれらの値は示しています。これらの値を正規化する方法はいくつかあります。詳細は ?confusionMatrix.train を参照のこと。

他クラス問題については、パフォーマンス計測のための追加の関数があります。クラス確率に基づく多項式対数尤度(小さいほうがよい)の負値を mnLogLoss は計算します。パラメータチューニングの最適化に用いられますが、他の計測(例、正確度あるいはROC曲線の下部面積)と矛盾する場合があり、特に他の計測が最も可能性のある値に近い場合に起こります。この関数は上述の他の関数の属性と似ています:

test_results <- predict(gbmFit3, testing, type = "prob")
test_results$obs <- testing$Class
head(test_results)

             M            R obs
1 1.982479e-05 9.999802e-01   R
2 8.774626e-07 9.999991e-01   R
3 1.789676e-11 1.000000e+00   R
4 6.024225e-04 9.993976e-01   R
5 9.999999e-01 1.179912e-07   R
6 9.995127e-01 4.872602e-04   R

mnLogLoss(test_results, lev = levels(test_results$obs))

  logLoss 
0.6746702 

加えて、multiClassSummary 関数は関連する計測値を計算します:

  • クラスの予測には全体の正確度とΚ統計値を用います
  • 多項式対数損失の負値(クラス確率が使用可ならば)
  • 感度、特異度、ROC曲線の下部面積等のような「一対他」統計量の平均値

例えば:

test_results$pred <- predict(gbmFit3, testing)
multiClassSummary(test_results, lev = levels(test_results$obs))

          logLoss               ROC          Accuracy             Kappa 
        0.6746702         0.9413580         0.8627451         0.7226107 
      Sensitivity       Specificity    Pos_Pred_Value    Neg_Pred_Value 
        0.9259259         0.7916667         0.8333333         0.9047619 
   Detection_Rate Balanced_Accuracy 
        0.4901961         0.8587963 

クラス確率の評価

2つのクラスを持つデータセットについてクラス確率予測の2つの関数が caret にはあります。

lift 関数は、ヒットのパーセントについての確率の閾値を評価するための関数です。(トレーニングセットではなく)予測の確率のセットと真のクラスのラベルをこの関数は要求します。例えば、twoClassSim 関数を用いて2クラスのサンプルをシミュレーション出来、トレーニングセットにモデルセットをフィット出来ます:

set.seed(2)
trainingSim <- twoClassSim(1000)
evalSim     <- twoClassSim(1000)
testingSim  <- twoClassSim(1000)

ctrl <- trainControl(method = "cv",
                     classProbs = TRUE,
                     summaryFunction = twoClassSummary)

set.seed(1045)
fdaModel <- train(Class ~ ., data = trainingSim,
                  method = "fda",
                  metric = "ROC",
                  tuneLength = 20,
                  trControl = ctrl)

set.seed(1045)
ldaModel <- train(Class ~ ., data = trainingSim,
                  method = "lda",
                  metric = "ROC",
                  trControl = ctrl)

set.seed(1045)
c5Model <- train(Class ~ ., data = trainingSim,
                 method = "C5.0",
                 metric = "ROC",
                 tuneLength = 10,
                 trControl = ctrl)

## A summary of the resampling results:
getTrainPerf(fdaModel)

  TrainROC TrainSens TrainSpec method
1 0.956762 0.9041751  0.848599    fda

getTrainPerf(ldaModel)

   TrainROC TrainSens TrainSpec method
1 0.9044245 0.8342761  0.791256    lda

getTrainPerf(c5Model)

   TrainROC TrainSens TrainSpec method
1 0.9495571 0.8676094 0.8531401   C5.0

これらのモデルから、評価セットの予測を得て、最初のクラスの確率を保存します:

evalResults <- data.frame(Class = evalSim$Class)
evalResults$FDA <- predict(fdaModel, evalSim, type = "prob")[,"Class1"]
evalResults$LDA <- predict(ldaModel, evalSim, type = "prob")[,"Class1"]
evalResults$C5.0 <- predict(c5Model, evalSim, type = "prob")[,"Class1"]
head(evalResults)

   Class        FDA       LDA      C5.0
1 Class1 0.99244077 0.8838205 0.8445830
2 Class1 0.99128497 0.7572450 0.8882418
3 Class1 0.82142101 0.8883830 0.5732098
4 Class2 0.04336463 0.0140480 0.1690251
5 Class1 0.77494981 0.9320695 0.4824400
6 Class2 0.11532541 0.0524154 0.3310495

lift 関数は、計算と関連するプロットの機能があり、lift 曲線(gain 曲線と呼ばれる)をプロットします。属性値は参照線を示します。(The value argument creates reference lines.)

trellis.par.set(caretTheme())
liftData <- lift(Class ~ FDA + LDA + C5.0, data = evalResults)
plot(liftData, values = 60, auto.key = list(columns = 3,
                                            lines = TRUE,
                                            points = FALSE))

(図略)

このヒットの60%を見ると、データの30%以上がサンプル(予測確率を並べた場合)となっていることがわかります。LDA モデルが他の2つのモデルよりも悪いこともわかります。

確率測定の別の関数もあります。 gbm パッケージ、rms パッケージなどです。これらのプロットが、予測確率の値の見積もりに使え、データの事象確率からなっています。(These plots can be used to assess whether the value of the probability prediction is consistent with the event rate in the data.) この関数についての書式は lift 関数ととても似ています:

trellis.par.set(caretTheme())
calData <- calibration(Class ~ FDA + LDA + C5.0,
                       data = evalResults,
                       cuts = 13)
plot(calData, type = "l", auto.key = list(columns = 3,
                                          lines = TRUE,
                                          points = FALSE))

図略

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1