LoginSignup
31
25

More than 5 years have passed since last update.

XGBoostの実装(R編)

Last updated at Posted at 2016-03-30

XGBoostとは

XGBoostとは,DMLCによって開発されているGradient Tree Boostingを実行するライブラリです.
C++, R, python, JuliaそしてJavaのライブラリが公開されています.
XGBoostとは,eXtreme Gradient Boostingの略称です.

Boosted trees

XGBoostでは,Random Forestの学習アルゴリズムを利用して教師あり学習を行う,Boosted treesが実行できます.
至極単純に言うと,Gradient Boosting(重み付きアンサンブル学習)とRandom Forest(決定木のアンサンブル学習)を組み合わせたアルゴリズムです(そのまんま).
理論については,ドキュメントにも書かれていますが,そのうち別記事で書く予定です.

XGBoost for R

XGBoostのRパッケージについては,CRANから入手できます.
ビルドが不要なため,install.packageのコマンドから簡単に利用できます.

実装

各種パッケージの入手

install.package("xgboost")
install.package(c("data.table","dplyr")) ## 入力データを変換するために必要
require(xgboost)
require(Matrix) ## dgCMatrixへ変換するために必要
require(data.table)

学習データの準備

今回はirisのデータセットを使用して,多クラス分類の学習,推定を行います.
操作はe1071パッケージのsvmなどを利用する時とほとんど同じなため,容易に実装できます.

入力データですが,一度data.table形式に変換し,更にdgCMatrix形式へ再変換する必要があります.
まず,irisの元データをdata.table形式に変換し,ラベルとなるSpeciesラベルをテーブルから除去します.

data(iris)

## irisの元データをdata.table形式に変換
df<-data.table(iris,keep.rownames=F)

## (変換前)
head(df)

   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
 1:          5.1         3.5          1.4         0.2  setosa
 2:          4.9         3.0          1.4         0.2  setosa
 3:          4.7         3.2          1.3         0.2  setosa
 4:          4.6         3.1          1.5         0.2  setosa
 5:          5.0         3.6          1.4         0.2  setosa
 6:          5.4         3.9          1.7         0.4  setosa

## Speciesラベルのデータを除去
df[,Species:=NULL]

## (変換後)
head(df)

   Sepal.Length Sepal.Width Petal.Length Petal.Width
1:          5.1         3.5          1.4         0.2
2:          4.9         3.0          1.4         0.2
3:          4.7         3.2          1.3         0.2
4:          4.6         3.1          1.5         0.2
5:          5.0         3.6          1.4         0.2
6:          5.4         3.9          1.7         0.4

data.tableからdgCMatrixに,ラベルを整数値に,それぞれ変換します.

## 先頭行とSpeciesラベルのデータを除いてdgCMatrix形式に変換
sparse_matrix<-sparse.model.matrix(Species~.-1,data=df)
## 3つのラベルを整数値に変換(0,1,2)
output_vector=as.integer(as.factor(iris$Species))-1

学習

xgboostで学習を行います.
irisは3クラスあるので,多クラス分類(objective="multi:softmax")を選択します.

## 多クラス分類
bst <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 4, eta = 1, nthread = 2, nround = 5, num_class=3, objective = "multi:softmax", verbose = 1)

[0] train-merror:0.020000
[1] train-merror:0.013333
[2] train-merror:0.006667
[3] train-merror:0.006667
[4] train-merror:0.000000

verbose > 1とすることで,各ステップで作成されたツリーの詳細が確認できます.

## 多クラス分類(ツリーの詳細)
bst <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 4, eta = 1, nthread = 2, nround = 5, num_class=3, objective = "multi:softmax", verbose = 2)

tree prunning end, 1 roots, 2 extra nodes, 0 pruned nodes ,max_depth=1
tree prunning end, 1 roots, 12 extra nodes, 0 pruned nodes ,max_depth=4
tree prunning end, 1 roots, 4 extra nodes, 0 pruned nodes ,max_depth=2
[0] train-merror:0.020000
tree prunning end, 1 roots, 2 extra nodes, 0 pruned nodes ,max_depth=1
tree prunning end, 1 roots, 10 extra nodes, 0 pruned nodes ,max_depth=4
tree prunning end, 1 roots, 8 extra nodes, 0 pruned nodes ,max_depth=3
[1] train-merror:0.013333
tree prunning end, 1 roots, 2 extra nodes, 0 pruned nodes ,max_depth=1
tree prunning end, 1 roots, 10 extra nodes, 0 pruned nodes ,max_depth=4
tree prunning end, 1 roots, 8 extra nodes, 0 pruned nodes ,max_depth=3
[2] train-merror:0.006667
tree prunning end, 1 roots, 2 extra nodes, 0 pruned nodes ,max_depth=1
tree prunning end, 1 roots, 8 extra nodes, 0 pruned nodes ,max_depth=4
tree prunning end, 1 roots, 10 extra nodes, 0 pruned nodes ,max_depth=3
[3] train-merror:0.006667
tree prunning end, 1 roots, 2 extra nodes, 0 pruned nodes ,max_depth=1
tree prunning end, 1 roots, 6 extra nodes, 0 pruned nodes ,max_depth=3
tree prunning end, 1 roots, 6 extra nodes, 0 pruned nodes ,max_depth=2
[4] train-merror:0.000000

その他,細かなパラメータ設定が可能で,チューニングも容易です.
各パラメータの詳細は,こちらをご確認下さい.

推定結果

prds<-predict(bst,sparse_matrix) ## 推定するデータは教師データと同じ
table(prds,iris$Species)

prds setosa versicolor virginica
   0     50          0         0
   1      0         50         0
   2      0          0        50
31
25
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
31
25