5
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

リフト曲線を描いてモデルの精度を確かめる

Last updated at Posted at 2018-08-20

##リフト曲線の定義
リフト曲線は、機械学習の領域において、モデルの精度を確認する上で役立ちます。
「リフト曲線」の定義は様々ありますが、
私はいつも、書籍『戦略的データサイエンス入門 ビジネスに活かすコンセプトとテクニック』の定義に従って、
リフト値をプロットしたもの、としています。

モデルの何かしらのロジックでインスタンスにスコアが付与され、その降順で並べ替えてスコア上位でフィルタしたときに、
どの程度正しく分類されたインスタンスがあるか、
その割合がランダムな分類器と比べてどの程度高いかを示したのが、リフト値です。

##リフト値の簡単な例
分類問題において、正例と負例が1:1の割合で含まれるデータセットに対してモデルを作成し、
スコア上位25%のところまでのデータを見たとします。
ランダムな分類器だとしたら、その中に正例は半分見つかるはずなので、リフト値は0.5/0.5 = 1です。
ここで別の分類器があり、スコア上位25%の中で正例の割合が95%を占めたとすると、リフト値は0.95/0.5 = 1.9です。
このリフト値を、スコア上位から見ていった時のインスタンス割合を横軸にとってプロットしたものがリフト曲線です。

見ていることは累積反応曲線と同じなのですが、
「ランダムなものより"約2倍の精度"がある」という直観的な言及がしやすいため、
リフト曲線を用いることが個人的には多いです。

##リフト曲線を描くには
リフト曲線を描くためには、
各インスタンスに付与されたスコアと、正例/負例のラベルが必要です。

スコア ラベル
0.1248 no
0.3422 no
0.9123 yes
0.7239 yes
0.7423 no
0.5214 yes
・・・ ・・・

##Rで実際に描いてみる

まずは、上記のデータを得るために、
RのC50パッケージに用意されているchurnデータセットを例にモデルを作成します。
(ここでは何も考えずに、caretパッケージを使ってtrain→predictをやってしまいます)

library(dplyr)
library(C50)
library(ggplot2)
library(caret)

data(churn)

# データ確認
str(churnTrain)
str(churnTest)

# モデル作成。スコアを付けたいため、trainControlでclassProbsを適用
fit.glmnnet <- train(churn ~ ., data = churnTrain, method = "glmnet", 
                     trControl = trainControl(classProbs = T))

# 予測の実行
pred <- predict(fit.glmnnet, churnTrain, type = "prob") %>% 
  cbind(select(churnTrain,churn))

str(pred)
# 'data.frame':	3333 obs. of  3 variables:
#  $ yes  : num  0.1372 0.0537 0.1227 0.4423 0.4498 ...
#  $ no   : num  0.863 0.946 0.877 0.558 0.55 ...
#  $ churn: Factor w/ 2 levels "yes","no": 2 2 2 2 2 2 2 2 2 2 ...

yesが対象となるスコア、churnが正例/負例のラベルです。
ここからリフト値を計算します。

pred.rownum <- nrow(pred)
instance.ratio <- (1:pred.rownum)/pred.rownum

# スコア降順に並べ替え
pred.arranged <- pred %>% 
  arrange(desc(yes)) %>% 
  mutate(res = (churn=="yes"))

# 該当インスタンスより上位にある正例をカウント
pred.arranged$res.cum <- cumsum(pred.arranged$res)

# リフト値を計算
lift <- 
  (pred.arranged$res.cum/(1:pred.rownum)) / 
  (pred.arranged$res.cum[pred.rownum]/pred.rownum)

# プロット。ついでに lift=1 の線も。
graph.lift <- ggplot() +
  geom_line(aes(x = instance.ratio,
                y = lift)) +
  geom_line(aes(x=c(0,1), y=c(1,1)), color = "red")
print(graph.lift)

Rplot.png

リフト曲線が描けました。
これを見れば、「"3倍の精度"が担保されるのは、インスタンス割合が約2割の地点である」ことがすぐわかり、そのときのスコアも下記のように求めることができます。

# リフト値3のインスタンス割合の算出
sum(lift>=3)/length(lift)
# 0.2118212

# 閾値となるスコアの算出
pred.arranged$yes[sum(lift>=3)]
# 0.2126706

##応用、他
複数モデルについて、モデルごとに同じ要領で計算し、色分けして描画すれば、
インスタンス割合とリフト値の関係が一目瞭然になります。

  • インスタンスを1割に絞り込むときの最良のモデルは何か
  • "5倍の精度"を担保するとき、インスタンス数を最も確保できるモデルは何か

等、複数の軸でモデルの比較検討することができます。

ただ、リフト値は、モデルの精度を示す一つの指標にすぎませんので、
他の指標についても確認する等、多角的に検討すべきであることを最後に付け加えておきます。

5
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?