#はじめに
今お仕事でカテゴリ分類の予測モデルを構築しています。
例えば、ビールのような多ブランド展開をしているような商品において「今Aブランドを好んで飲んでいる人が、半年後はどのブランドを飲んでいそうか?」ということを当てるようなことをやっています。
で、この予測モデル、ただ単に精度が高ければ良いわけではなく「マイナーなブランドの分類精度もある程度担保してほしい」というビジネス上のオーダーがありました。不均衡データでそのまま分類モデルを作ると、どうしてもメジャーなブランドへの予測確率が高くなるように予測されやすくなるので、それは避けてほしい、ということでした。
手法はあまり複雑なことや色々な手法を試している暇が無いので、コンペでお馴染みのxgboostでやるとして、その際に上記のオーダーを満たすために使っているテクニックとして「sample weights」を使用しています。実際のkaggle等のコンペでも、小カテゴリの精度が全体の評価指標に大きく影響しそうな場合はよく使われているテクニックです。今回の記事では、sample weightsの説明と実際の分類がどう変わるのか、をご紹介します。
#不均衡データの分類方法
不均衡データにおけるsamplingを見ると、
大きく不均衡データには以下のようなアプローチがあります
- algorithm-level approaches
- 不均衡を調整する係数をモデルに導入する(コスト関数を調整する)
- data-level approaches
- 多数派データを減少させ、少数派データを増加させる手法。順序については、undersampling, oversamplingの順に行う。
- Rでは「SMOTE」などのパッケージが有名で、手軽。
今回はxgboostにおいてalgorithm-level approachesである「sample weights」に関して、何も考えず分類した場合とweightを設定した場合を分けて結果を確認してみようと思います。
#xgboostとは
わざわざこの記事で説明する必要はないですね。WEBで調べれば山ほど参考資料が出てきます。qiita記事では、以下がとても参考になりました。
XGBoostの主な特徴と理論の概要
さらに深く知りたい方は以下のpdfがおすすめです。(英語)
Introduction to Boosted Trees
#sample weightsとは
厳密性は捨て、できるだけ数式を使わないざっくりとした説明をします。
まず、そもそもなぜxgboostで小カテゴリが予測されづらくなるのでしょうか。
xgboostでは木を分割するとき、情報損失の減少幅がもっとも少なくなるように分割していきます。そして、その損失の減少幅を計算するときは損失の勾配(grad/hessian)の大きさから計算していきます。このとき、目的変数がなかなか出てこないレアなカテゴリであった場合は情報損失の減少幅にあまり影響を及ぼしません。これが「小カテゴリが予測されづらくなる」原因になります。
これを防ぐため、xgboostではサンプル毎に重みをつけることができます。それが「sample weight」と呼んでいるものです。小カテゴリに所属するサンプルに重みを高くつけてあげることで勾配を大きくし、情報損失を恣意的に大きくしてあげることが可能になります。
もう少しわかりやすく言うと、小カテゴリの分類が間違った時のペナルティを大きくしてあげることで、小カテゴリを分類しやすくしているのです。このペナルティを「sample weights」で設定できるのです。
以上、数式を使わないめちゃくちゃざっくりとした説明になりますが、詳しく知りたい方は上記のxgboostの説明に加えて以下のリンクをご確認ください。
How does xgboost handle instance weights
#コード例
では、カテゴリ分類を普通に行った場合とweightsを付けた場合に、全体の予測精度とそれぞれのカテゴリの分類精度がどう変わるのかを確かめてみましょう。
データセットはkaggleのカテゴリ分類の有名コンペである[Otto Group Product Classification Challenge]
(https://www.kaggle.com/c/otto-group-product-classification-challenge/data)のデータセットを用います
library(xgboost)
library(caret)
library(data.table)
# データの取り込み
data <- fread("C:/…/train.csv",header=TRUE,stringsAsFactors=TRUE,data.table = F)
# カテゴリ名変更(xgboostのカテゴリ名は数値の0~カテゴリ数-1に揃える必要がある)
for (i in 1:9){
data$target_num[data$target == paste('Class_',i,sep = '')] <- i-1
}
#データ型変更
cat_num <- nrow(table(data[,c("target_num")]))
data <- data[,-c(1,95)]
data[] <- lapply(data, as.numeric)
目的変数となる各カテゴリの分布は以下のようになります。
お世辞にも各カテゴリが綺麗に分布しているとは言えませんね。
そのままカテゴリ分類した場合
まずは、何も考えずにカテゴリ分類してみます。
set.seed(100)
sample_num <- sample(nrow(data), floor(nrow(data) * 0.2))
data_mod <- data[sample_num,]
target_mod <- data_mod$target_num
data_mod <- data_mod[,-94]
data_mod2 <- xgb.DMatrix(data.matrix(data_mod),label = target_mod)
data_pred <- data[-sample_num,]
target_pred <- data_pred$target_num
data_pred <- data_pred[,-94]
data_pred2 <- xgb.DMatrix(data.matrix(data_pred),label = target_pred)
# model_param
param <- list("objective" = "multi:softmax"
,"eval_metric" = "merror"
,"eta" = 0.1
,"max_depth" = 6
,"min_child_weight" = 3
,"subsample" = 1
,"colsample_bytree" = 1
,"num_class" = cat_num
)
# 木の数を5hold-CVで決定
model.cv = xgb.cv(param=param,
data = data_mod2,
nfold = 5,
nrounds = 100,
early_stopping_rounds = 5
)
nround <- model.cv$niter - 5 #ベストな学習数
# モデル構築・予測
model = xgboost(param=param
,data=data_mod2
,nrounds=nround
)
pred <- predict(object=model,newdata=data_pred2)
最初に、全体の分類精度を見てみましょう。
# 精度の確認
confusionMatrix(pred, target_pred)
Confusion Matrix and Statistics
Reference
Prediction 0 1 2 3 4 5 6 7 8
0 553 7 0 1 0 57 39 75 95
1 100 11129 3573 988 64 161 292 128 159
2 12 1440 2596 333 7 27 198 22 9
3 4 100 76 702 1 16 39 0 3
4 11 36 1 16 2130 0 15 6 14
5 151 38 9 90 8 10531 171 207 152
6 65 86 96 33 5 153 1354 72 27
7 275 23 21 3 0 214 146 6170 166
8 372 15 10 4 4 154 16 111 3316
Overall Statistics
Accuracy : 0.7773
95% CI : (0.7737, 0.781)
No Information Rate : 0.2601
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.7276
Mcnemar's Test P-Value : < 2.2e-16
Statistics by Class:
Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
Sensitivity 0.35839 0.8645 0.40677 0.32350 0.95989 0.9309 0.59648 0.9086 0.84141
Specificity 0.99429 0.8508 0.95251 0.99495 0.99791 0.9784 0.98863 0.9801 0.98494
Pos Pred Value 0.66868 0.6707 0.55900 0.74601 0.95559 0.9273 0.71602 0.8792 0.82859
Neg Pred Value 0.97966 0.9470 0.91560 0.96977 0.99812 0.9795 0.98076 0.9854 0.98626
Prevalence 0.03117 0.2601 0.12892 0.04384 0.04483 0.2285 0.04586 0.1372 0.07961
Detection Rate 0.01117 0.2248 0.05244 0.01418 0.04303 0.2127 0.02735 0.1246 0.06699
Detection Prevalence 0.01671 0.3352 0.09381 0.01901 0.04503 0.2294 0.03820 0.1418 0.08084
Balanced Accuracy 0.67634 0.8576 0.67964 0.65923 0.97890 0.9546 0.79255 0.9444 0.91318
全体のaccuracyは0.78,kappa係数も0.73と、カテゴリ数が多い割にはなかなかの分類精度では無いでしょうか。しかし、小カテゴリに目を向けてみると、数が少ないclass0,3,6のSensivity(=recall)が低いことがわかります。
次に「予測されたカテゴリ」と「実際のカテゴリ」の数を比較してみましょう。
赤が実際の各カテゴリの数、緑が予測された各カテゴリの数を示します。
大カテゴリのカテゴリ1は実際より多く分類されているのに比較して、小カテゴリのカテゴリ0,3,6は実際より数が少なめに予測されていますね。あとカテゴリ2も分類精度が高くないので少な目に予測されていることがわかります。
weightをつけて分類した場合
weightは、xgb.DMatrixの引数として設定することができます。
今回はざっくり小さいカテゴリの重みを「4~3」、中くらいのカテゴリの重みを「2~3」、大きなカテゴリの重みを「1~2」としてみました。細かい数値は結果を見ながらちょこちょこ調整しています。
# weight付け
weight <- as.data.frame(floor(1 / (table(data[,94]) / nrow(data))))
names(weight) <- c("target_num","weight")
# 各カテゴリのweight
weight$weight <- c(4,1,1.5,3,3,1.2,3,2,2.5)
weight$target_num <- as.numeric(weight$target_num) - 1
weight_data <- dplyr::left_join(data
,weight
,by= c("target_num")
)
weight_data <- weight_data[,95]
# weightをつけてxgb.DMatrixを作成
data_mod2 <- xgb.DMatrix(data.matrix(data_mod)
,label = target_mod
,weight = weight_mod
)
data_pred2 <- xgb.DMatrix(data.matrix(data_pred)
,label = target_pred
,weight = weight_pred
)
# 後は省略
全体の精度は以下のようになりました。前と比べて、Accuracyは0.7773 → 0.7733と若干下がっていますが、ほぼ変わりありませんね。また、大カテゴリ(category1)の分類精度が下がっている代わりに、さきほど分類精度がイマイチだった小カテゴリ(category0,3,6)の分類精度が上がっていることがわかります。
Confusion Matrix and Statistics
Reference
Prediction 0 1 2 3 4 5 6 7 8
0 773 42 12 4 6 145 88 210 199
1 41 9660 2429 639 45 97 140 69 76
2 20 2425 3472 407 10 30 161 32 10
3 5 416 263 967 2 40 81 0 13
4 15 71 5 18 2143 6 14 10 23
5 87 25 4 68 5 10295 107 135 103
6 61 160 148 52 4 210 1517 103 30
7 205 38 26 6 0 264 139 6095 127
8 336 37 23 9 4 226 23 137 3360
Overall Statistics
Accuracy : 0.7733
95% CI : (0.7696, 0.777)
No Information Rate : 0.2601
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.7272
Mcnemar's Test P-Value : < 2.2e-16
Statistics by Class:
Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
Sensitivity 0.50097 0.7503 0.54403 0.44562 0.96575 0.9100 0.66828 0.8975 0.85258
Specificity 0.98528 0.9035 0.92823 0.98268 0.99657 0.9860 0.98374 0.9812 0.98255
Pos Pred Value 0.52265 0.7320 0.52870 0.54113 0.92972 0.9507 0.66389 0.8833 0.80866
Neg Pred Value 0.98397 0.9115 0.93222 0.97479 0.99839 0.9737 0.98405 0.9837 0.98719
Prevalence 0.03117 0.2601 0.12892 0.04384 0.04483 0.2285 0.04586 0.1372 0.07961
Detection Rate 0.01562 0.1951 0.07014 0.01953 0.04329 0.2080 0.03064 0.1231 0.06787
Detection Prevalence 0.02988 0.2666 0.13266 0.03610 0.04656 0.2188 0.04616 0.1394 0.08393
Balanced Accuracy 0.74313 0.8269 0.73613 0.71415 0.98116 0.9480 0.82601 0.9393 0.91756
次に「予測されたカテゴリ」と「実際のカテゴリ」の数を比較すると、さきほどの例と違い、予測されたカテゴリと実際のカテゴリの数がかなり近づいていることがわかります。おそらく、小カテゴリが分類されやすくなったことによる影響だと思われます。
#まとめ
分類問題で、小カテゴリの分類精度や分類の偏りを無くしたい場合、ビジネスの現場ならsample weightsの調整がかなり良さそうかなと思っています。何より手軽だし、重みを調整して試行錯誤しやすいのも良いですね。皆さまも、同じようなシチュエーションに遭遇したらぜひ使ってみてください。
#補足
小カテゴリとはいえ、サンプルがある程度確保できるようならdownsampling+baggingでも良いかもしれません。
不均衡データをdownsampling + baggingで補正すると汎化性能も確保できて良さそう