49
36

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

BrainPadAdvent Calendar 2017

Day 7

xgboostで小さいカテゴリもちゃんと分類するテクニック「sample weights」

Last updated at Posted at 2017-12-06

#はじめに
今お仕事でカテゴリ分類の予測モデルを構築しています。
例えば、ビールのような多ブランド展開をしているような商品において「今Aブランドを好んで飲んでいる人が、半年後はどのブランドを飲んでいそうか?」ということを当てるようなことをやっています。

で、この予測モデル、ただ単に精度が高ければ良いわけではなく「マイナーなブランドの分類精度もある程度担保してほしい」というビジネス上のオーダーがありました。不均衡データでそのまま分類モデルを作ると、どうしてもメジャーなブランドへの予測確率が高くなるように予測されやすくなるので、それは避けてほしい、ということでした。

手法はあまり複雑なことや色々な手法を試している暇が無いので、コンペでお馴染みのxgboostでやるとして、その際に上記のオーダーを満たすために使っているテクニックとして「sample weights」を使用しています。実際のkaggle等のコンペでも、小カテゴリの精度が全体の評価指標に大きく影響しそうな場合はよく使われているテクニックです。今回の記事では、sample weightsの説明と実際の分類がどう変わるのか、をご紹介します。

#不均衡データの分類方法
不均衡データにおけるsamplingを見ると、
大きく不均衡データには以下のようなアプローチがあります

  • algorithm-level approaches
    • 不均衡を調整する係数をモデルに導入する(コスト関数を調整する)
  • data-level approaches
    • 多数派データを減少させ、少数派データを増加させる手法。順序については、undersampling, oversamplingの順に行う。
    • Rでは「SMOTE」などのパッケージが有名で、手軽。
        

今回はxgboostにおいてalgorithm-level approachesである「sample weights」に関して、何も考えず分類した場合とweightを設定した場合を分けて結果を確認してみようと思います。

#xgboostとは
わざわざこの記事で説明する必要はないですね。WEBで調べれば山ほど参考資料が出てきます。qiita記事では、以下がとても参考になりました。
XGBoostの主な特徴と理論の概要

さらに深く知りたい方は以下のpdfがおすすめです。(英語)
Introduction to Boosted Trees

#sample weightsとは
厳密性は捨て、できるだけ数式を使わないざっくりとした説明をします。

まず、そもそもなぜxgboostで小カテゴリが予測されづらくなるのでしょうか。

xgboostでは木を分割するとき、情報損失の減少幅がもっとも少なくなるように分割していきます。そして、その損失の減少幅を計算するときは損失の勾配(grad/hessian)の大きさから計算していきます。このとき、目的変数がなかなか出てこないレアなカテゴリであった場合は情報損失の減少幅にあまり影響を及ぼしません。これが「小カテゴリが予測されづらくなる」原因になります。

これを防ぐため、xgboostではサンプル毎に重みをつけることができます。それが「sample weight」と呼んでいるものです。小カテゴリに所属するサンプルに重みを高くつけてあげることで勾配を大きくし、情報損失を恣意的に大きくしてあげることが可能になります。
もう少しわかりやすく言うと、小カテゴリの分類が間違った時のペナルティを大きくしてあげることで、小カテゴリを分類しやすくしているのです。このペナルティを「sample weights」で設定できるのです。

以上、数式を使わないめちゃくちゃざっくりとした説明になりますが、詳しく知りたい方は上記のxgboostの説明に加えて以下のリンクをご確認ください。
How does xgboost handle instance weights

#コード例
では、カテゴリ分類を普通に行った場合とweightsを付けた場合に、全体の予測精度とそれぞれのカテゴリの分類精度がどう変わるのかを確かめてみましょう。
データセットはkaggleのカテゴリ分類の有名コンペである[Otto Group Product Classification Challenge]
(https://www.kaggle.com/c/otto-group-product-classification-challenge/data)のデータセットを用います

library(xgboost)
library(caret)
library(data.table)

# データの取り込み
data <- fread("C:/…/train.csv",header=TRUE,stringsAsFactors=TRUE,data.table = F)

# カテゴリ名変更(xgboostのカテゴリ名は数値の0~カテゴリ数-1に揃える必要がある)
for (i in 1:9){
  data$target_num[data$target == paste('Class_',i,sep = '')] <- i-1
  }

#データ型変更
cat_num <- nrow(table(data[,c("target_num")]))
data <- data[,-c(1,95)]
data[] <- lapply(data, as.numeric)

目的変数となる各カテゴリの分布は以下のようになります。
お世辞にも各カテゴリが綺麗に分布しているとは言えませんね。

Rplot.png

そのままカテゴリ分類した場合

まずは、何も考えずにカテゴリ分類してみます。

set.seed(100)
sample_num <- sample(nrow(data), floor(nrow(data) * 0.2))
data_mod <- data[sample_num,]
target_mod <- data_mod$target_num
data_mod <- data_mod[,-94]
data_mod2 <- xgb.DMatrix(data.matrix(data_mod),label = target_mod)

data_pred <- data[-sample_num,]
target_pred <- data_pred$target_num
data_pred <- data_pred[,-94]
data_pred2 <- xgb.DMatrix(data.matrix(data_pred),label = target_pred)

# model_param
param <- list("objective" = "multi:softmax" 
              ,"eval_metric" = "merror" 
              ,"eta" = 0.1
              ,"max_depth" = 6
              ,"min_child_weight" = 3
              ,"subsample" = 1
              ,"colsample_bytree" = 1
              ,"num_class" = cat_num
              )

# 木の数を5hold-CVで決定
model.cv = xgb.cv(param=param, 
                  data = data_mod2,
                  nfold = 5,
                  nrounds = 100,
                  early_stopping_rounds = 5
)
nround <- model.cv$niter - 5 #ベストな学習数

# モデル構築・予測
model = xgboost(param=param
                ,data=data_mod2
                ,nrounds=nround
)
pred <- predict(object=model,newdata=data_pred2)

最初に、全体の分類精度を見てみましょう。

# 精度の確認
confusionMatrix(pred, target_pred)

Confusion Matrix and Statistics

          Reference
Prediction     0     1     2     3     4     5     6     7     8
         0   553     7     0     1     0    57    39    75    95
         1   100 11129  3573   988    64   161   292   128   159
         2    12  1440  2596   333     7    27   198    22     9
         3     4   100    76   702     1    16    39     0     3
         4    11    36     1    16  2130     0    15     6    14
         5   151    38     9    90     8 10531   171   207   152
         6    65    86    96    33     5   153  1354    72    27
         7   275    23    21     3     0   214   146  6170   166
         8   372    15    10     4     4   154    16   111  3316

Overall Statistics
                                         
               Accuracy : 0.7773         
                 95% CI : (0.7737, 0.781)
    No Information Rate : 0.2601         
    P-Value [Acc > NIR] : < 2.2e-16      
                                         
                  Kappa : 0.7276         
 Mcnemar's Test P-Value : < 2.2e-16      

Statistics by Class:

                     Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
Sensitivity           0.35839   0.8645  0.40677  0.32350  0.95989   0.9309  0.59648   0.9086  0.84141
Specificity           0.99429   0.8508  0.95251  0.99495  0.99791   0.9784  0.98863   0.9801  0.98494
Pos Pred Value        0.66868   0.6707  0.55900  0.74601  0.95559   0.9273  0.71602   0.8792  0.82859
Neg Pred Value        0.97966   0.9470  0.91560  0.96977  0.99812   0.9795  0.98076   0.9854  0.98626
Prevalence            0.03117   0.2601  0.12892  0.04384  0.04483   0.2285  0.04586   0.1372  0.07961
Detection Rate        0.01117   0.2248  0.05244  0.01418  0.04303   0.2127  0.02735   0.1246  0.06699
Detection Prevalence  0.01671   0.3352  0.09381  0.01901  0.04503   0.2294  0.03820   0.1418  0.08084
Balanced Accuracy     0.67634   0.8576  0.67964  0.65923  0.97890   0.9546  0.79255   0.9444  0.91318

全体のaccuracyは0.78,kappa係数も0.73と、カテゴリ数が多い割にはなかなかの分類精度では無いでしょうか。しかし、小カテゴリに目を向けてみると、数が少ないclass0,3,6のSensivity(=recall)が低いことがわかります。

次に「予測されたカテゴリ」と「実際のカテゴリ」の数を比較してみましょう。
赤が実際の各カテゴリの数、緑が予測された各カテゴリの数を示します。

大カテゴリのカテゴリ1は実際より多く分類されているのに比較して、小カテゴリのカテゴリ0,3,6は実際より数が少なめに予測されていますね。あとカテゴリ2も分類精度が高くないので少な目に予測されていることがわかります。
Rplot01.png

weightをつけて分類した場合

weightは、xgb.DMatrixの引数として設定することができます。
今回はざっくり小さいカテゴリの重みを「4~3」、中くらいのカテゴリの重みを「2~3」、大きなカテゴリの重みを「1~2」としてみました。細かい数値は結果を見ながらちょこちょこ調整しています。

# weight付け
weight <- as.data.frame(floor(1 / (table(data[,94]) / nrow(data))))
names(weight) <- c("target_num","weight") 

# 各カテゴリのweight
weight$weight <- c(4,1,1.5,3,3,1.2,3,2,2.5)

weight$target_num <- as.numeric(weight$target_num) - 1
weight_data <- dplyr::left_join(data
                           ,weight
                           ,by= c("target_num")
                          ) 
weight_data <- weight_data[,95]

# weightをつけてxgb.DMatrixを作成
data_mod2 <- xgb.DMatrix(data.matrix(data_mod)
                         ,label = target_mod
                         ,weight = weight_mod
                         )
data_pred2 <- xgb.DMatrix(data.matrix(data_pred)
                         ,label = target_pred
                         ,weight = weight_pred
                         )

# 後は省略

全体の精度は以下のようになりました。前と比べて、Accuracyは0.7773 → 0.7733と若干下がっていますが、ほぼ変わりありませんね。また、大カテゴリ(category1)の分類精度が下がっている代わりに、さきほど分類精度がイマイチだった小カテゴリ(category0,3,6)の分類精度が上がっていることがわかります。


Confusion Matrix and Statistics

          Reference
Prediction     0     1     2     3     4     5     6     7     8
         0   773    42    12     4     6   145    88   210   199
         1    41  9660  2429   639    45    97   140    69    76
         2    20  2425  3472   407    10    30   161    32    10
         3     5   416   263   967     2    40    81     0    13
         4    15    71     5    18  2143     6    14    10    23
         5    87    25     4    68     5 10295   107   135   103
         6    61   160   148    52     4   210  1517   103    30
         7   205    38    26     6     0   264   139  6095   127
         8   336    37    23     9     4   226    23   137  3360

Overall Statistics
                                         
               Accuracy : 0.7733         
                 95% CI : (0.7696, 0.777)
    No Information Rate : 0.2601         
    P-Value [Acc > NIR] : < 2.2e-16      
                                         
                  Kappa : 0.7272         
 Mcnemar's Test P-Value : < 2.2e-16      

Statistics by Class:

                     Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
Sensitivity           0.50097   0.7503  0.54403  0.44562  0.96575   0.9100  0.66828   0.8975  0.85258
Specificity           0.98528   0.9035  0.92823  0.98268  0.99657   0.9860  0.98374   0.9812  0.98255
Pos Pred Value        0.52265   0.7320  0.52870  0.54113  0.92972   0.9507  0.66389   0.8833  0.80866
Neg Pred Value        0.98397   0.9115  0.93222  0.97479  0.99839   0.9737  0.98405   0.9837  0.98719
Prevalence            0.03117   0.2601  0.12892  0.04384  0.04483   0.2285  0.04586   0.1372  0.07961
Detection Rate        0.01562   0.1951  0.07014  0.01953  0.04329   0.2080  0.03064   0.1231  0.06787
Detection Prevalence  0.02988   0.2666  0.13266  0.03610  0.04656   0.2188  0.04616   0.1394  0.08393
Balanced Accuracy     0.74313   0.8269  0.73613  0.71415  0.98116   0.9480  0.82601   0.9393  0.91756

次に「予測されたカテゴリ」と「実際のカテゴリ」の数を比較すると、さきほどの例と違い、予測されたカテゴリと実際のカテゴリの数がかなり近づいていることがわかります。おそらく、小カテゴリが分類されやすくなったことによる影響だと思われます。

Rplot02.png

#まとめ
分類問題で、小カテゴリの分類精度や分類の偏りを無くしたい場合、ビジネスの現場ならsample weightsの調整がかなり良さそうかなと思っています。何より手軽だし、重みを調整して試行錯誤しやすいのも良いですね。皆さまも、同じようなシチュエーションに遭遇したらぜひ使ってみてください。

#補足
小カテゴリとはいえ、サンプルがある程度確保できるようならdownsampling+baggingでも良いかもしれません。
不均衡データをdownsampling + baggingで補正すると汎化性能も確保できて良さそう

49
36
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
49
36

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?