Extracting human-readable rules from tree-based machine learning models (Random Forests, XGBoost) with defragTrees

Posted at 2019-02-15

Introduction

Background

Machine learning models have grown more and more complex in recent years. Especially when a machine learning model is used to explain something in a business setting, it is difficult for people to interpret a model that has become a black box. There has therefore been a demand for extracting simple rules from such black-box models that anyone can understand at a glance.

Contents

In this article we build the following tree-based machine learning models:

  • Random Forests
  • XGBoost

We then use defragTrees to try to extract simplified rules from them.
The procedure for LightGBM is almost the same, so please refer to the original page.

Academic background

This article is based on the paper Making Tree Ensembles Interpretable: A Bayesian Model Selection Approach. For the paper's source code, see defragTrees.
If you prefer slides in Japanese, see アンサンブル木モデル解釈のためのモデル簡略化法.

Installation

defragTrees cannot be installed with pip or via a git+URL. Instead, download defragTrees.py from here. Place the downloaded defragTrees.py in the same folder as your own source file, and you can then use it with "from defragTrees import DefragModel".
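
For reference, a minimal way to fetch the file from Python is shown below. This is not from the original article, and it assumes defragTrees.py still sits at the root of the sato9hara/defragTrees GitHub repository; if the path has changed, download the file manually instead.

import urllib.request

# Assumed raw-file location in the paper author's repository (verify before use).
url = "https://raw.githubusercontent.com/sato9hara/defragTrees/master/defragTrees.py"
urllib.request.urlretrieve(url, "defragTrees.py")  # saved next to your own script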

import numpy as np
from sklearn.datasets import make_classification
from imblearn.ensemble import BalancedRandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.metrics import accuracy_score, cohen_kappa_score, balanced_accuracy_score, make_scorer, f1_score, recall_score
from sklearn.ensemble import RandomForestClassifier
import lightgbm as lgb
import xgboost as xgb
from defragTrees import DefragModel  # the downloaded defragTrees.py
# Note: sklearn.externals.joblib was removed in newer scikit-learn versions;
# with a recent version, use "from joblib import dump, load" instead.
from sklearn.externals.joblib import dump
from sklearn.externals.joblib import load

Experiment

Creating test data

Parameter values

  • n_samples: the number of samples
    • for a questionnaire survey, this corresponds to the number of respondents
  • n_features: the number of features
    • the number of features used to distinguish the target classes
    • for a questionnaire survey, for example, these are the question-answer pairs
  • n_informative: the number of features that are actually related to the target classes
    • the larger this number, the easier prediction becomes
  • weights: the proportion of samples assigned to each class (here [0.7, 0.3], giving an imbalanced two-class problem)
  • n_classes: the number of target classes
    • set this to 2 for a binary classification problem
  • random_state: the random seed
data = make_classification(n_samples=1000,    # number of samples to generate
                           n_features=110,
                           n_informative=100,
                           weights=[0.7, 0.3], # class proportions (imbalanced)
                           n_classes=2,
                           random_state=43)
data_set = data[0]    # feature matrix
target_set = data[1]  # class labels
# Alternatively, the iris dataset can be used:
#from sklearn.datasets import load_iris
#iris = load_iris()
#data_set = iris.data
#target_set = iris.target
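
As a quick sanity check (not in the original code), the class imbalance produced by weights=[0.7, 0.3] can be confirmed by counting the labels:

print(np.bincount(target_set))  # roughly [700, 300]: class 0 is the majority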

Random Forests

Building the model

model = BalancedRandomForestClassifier(random_state=43,
                                       n_jobs=1,
                                       n_estimators=500,
                                       max_features="log2",
                                       class_weight='balanced',
                                       sampling_strategy='all',
                                       max_depth=None,
                                       oob_score=False)
scoring = {'accuracy': make_scorer(accuracy_score),
           'kappa': make_scorer(cohen_kappa_score),
           'balanced_accuracy': make_scorer(balanced_accuracy_score)}
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_validate(model, data_set, target_set, n_jobs=1,
                        cv=skf, return_train_score=True, scoring=scoring)
print(scores["test_accuracy"])
print(scores["test_kappa"])
print(scores["test_balanced_accuracy"])
model.fit(data_set, target_set)
[0.735 0.765 0.795 0.765 0.815]
[0.46356275 0.49570815 0.56008584 0.5        0.59606987]
[0.77261905 0.775      0.81071429 0.7797619  0.825     ]





BalancedRandomForestClassifier(bootstrap=True, class_weight='balanced',
                criterion='gini', max_depth=None, max_features='log2',
                max_leaf_nodes=None, min_impurity_decrease=0.0,
                min_samples_leaf=2, min_samples_split=2,
                min_weight_fraction_leaf=0.0, n_estimators=500, n_jobs=1,
                oob_score=False, random_state=43, replacement=False,
                sampling_strategy='all', verbose=0, warm_start=False)
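
Since dump and load from joblib are imported above but never used in the article, here is a minimal sketch of how the fitted forest could be persisted and reloaded (the file name is arbitrary):

dump(model, 'balanced_rf.joblib')   # save the fitted model to disk
model = load('balanced_rf.joblib')  # reload it later without refitting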

Extracting simple rules from Random Forests

splitter = DefragModel.parseSLtrees(model)  # parse the sklearn tree ensemble into an array of (feature index, threshold) splits
mdl = DefragModel(modeltype='classification', maxitr=100, qitr=0, tol=1e-6, restart=20, verbose=0)
mdl.fit(data_set, target_set, splitter, 10, fittype='FAB')  # 10 = maximum number of rules (Kmax)

# results
score, cover, coll = mdl.evaluate(data_set, target_set)
print()
print('<< defragTrees >>')
print('----- Evaluated Results -----')
print('Test Error = %f' % (score,))
print('Test Coverage = %f' % (cover,))
print('Overlap = %f' % (coll,))
print()
print('----- Found Rules -----')
print(mdl)
[Seed   0] TrainingError = 0.30, K = 2
[Seed   1] TrainingError = 0.30, K = 3
[Seed   2] TrainingError = 0.30, K = 2
[Seed   3] TrainingError = 0.30, K = 3
[Seed   4] TrainingError = 0.30, K = 2
[Seed   5] TrainingError = 0.30, K = 2
[Seed   6] TrainingError = 0.30, K = 2
[Seed   7] TrainingError = 0.30, K = 2
[Seed   8] TrainingError = 0.30, K = 3
[Seed   9] TrainingError = 0.30, K = 2
[Seed  10] TrainingError = 0.30, K = 3
[Seed  11] TrainingError = 0.30, K = 2
[Seed  12] TrainingError = 0.30, K = 1
[Seed  13] TrainingError = 0.30, K = 2
[Seed  14] TrainingError = 0.30, K = 1
[Seed  15] TrainingError = 0.30, K = 2
[Seed  16] TrainingError = 0.30, K = 2
[Seed  17] TrainingError = 0.30, K = 2
[Seed  18] TrainingError = 0.30, K = 2
[Seed  19] TrainingError = 0.30, K = 3
Optimal Model >> Seed   0, TrainingError = 0.30, K = 2

<< defragTrees >>
----- Evaluated Results -----
Test Error = 0.300000
Test Coverage = 1.000000
Overlap = 0.942000

----- Found Rules -----
[Rule  1]
y = 0 when
     x_31 < 14.781519

[Rule  2]
y = 0 when
     x_13 >= -13.205372
     x_28 < 13.061854
     x_96 < 12.163099

[Otherwise]
y = 0

Interpreting the results

Two rules were extracted for classifying this data.

  • Rule 1
    • y = 0 when
      • x_31 < 14.781519
    • Interpretation: when the value of the 31st feature (x_31) is less than 14.781519, the class label (y) is classified as 0
  • Rule 2
    • y = 0 when
      • x_13 >= -13.205372
      • x_28 < 13.061854
      • x_96 < 12.163099
    • Interpretation: when the value of the 13th feature (x_13) is at least -13.205372, the value of the 28th feature (x_28) is less than 13.061854, and the value of the 96th feature (x_96) is less than 12.163099, the class label (y) is classified as 0
  • Otherwise, the class label is 0
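
To make the rules concrete, the following sketch (not part of the original article) applies them to a single sample by hand. It assumes defragTrees numbers features from 1, so x_31 corresponds to column 30 of data_set.

def apply_extracted_rules(x):
    # Rule 1: x_31 < 14.781519
    if x[30] < 14.781519:
        return 0
    # Rule 2: x_13 >= -13.205372 and x_28 < 13.061854 and x_96 < 12.163099
    if x[12] >= -13.205372 and x[27] < 13.061854 and x[95] < 12.163099:
        return 0
    # Otherwise
    return 0

print(apply_extracted_rules(data_set[0]))

Note that every branch returns class 0, so the simplified model effectively always predicts the majority class; this is consistent with the reported Test Error of 0.300000 on a 70/30 class split.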

XGBoost

Building the model

model = xgb.XGBClassifier(max_depth=50,
                          learning_rate=0.16,
                          min_child_weight=1,
                          n_estimators=200)
scoring = {'accuracy': make_scorer(accuracy_score),
           'kappa': make_scorer(cohen_kappa_score),
           'balanced_accuracy': make_scorer(balanced_accuracy_score)}
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_validate(model, data_set, target_set, n_jobs=1,
                        cv=skf, return_train_score=True, scoring=scoring)
print(scores["test_accuracy"])
print(scores["test_kappa"])
print(scores["test_balanced_accuracy"])
num_round = 50
dtrain = xgb.DMatrix(data_set, label=target_set)
# Note: 'n_estimators' belongs to the sklearn wrapper; xgb.train ignores it
# and uses num_round (num_boost_round) instead.
param = {'max_depth': 50, 'learning_rate': 0.16, 'min_child_weight': 1, 'n_estimators': 200}
bst = xgb.train(param, dtrain, num_round)

# dump the trained booster as text so defragTrees can parse it
bst.dump_model('xgbmodel.txt')

[0.8   0.825 0.88  0.8   0.86 ]
[0.46236559 0.54188482 0.67741935 0.44444444 0.62365591]
[0.7047619  0.74642857 0.8        0.69047619 0.77619048]
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 194 extra nodes, 0 pruned nodes, max_depth=14
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 212 extra nodes, 0 pruned nodes, max_depth=14
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 210 extra nodes, 0 pruned nodes, max_depth=18
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 208 extra nodes, 0 pruned nodes, max_depth=18
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 226 extra nodes, 0 pruned nodes, max_depth=23
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 208 extra nodes, 0 pruned nodes, max_depth=18
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 208 extra nodes, 0 pruned nodes, max_depth=18
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 200 extra nodes, 0 pruned nodes, max_depth=14
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 216 extra nodes, 0 pruned nodes, max_depth=12
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 222 extra nodes, 0 pruned nodes, max_depth=15
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 202 extra nodes, 0 pruned nodes, max_depth=12
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 206 extra nodes, 0 pruned nodes, max_depth=17
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 222 extra nodes, 0 pruned nodes, max_depth=20
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 242 extra nodes, 0 pruned nodes, max_depth=18
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 246 extra nodes, 0 pruned nodes, max_depth=15
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 284 extra nodes, 0 pruned nodes, max_depth=22
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 312 extra nodes, 0 pruned nodes, max_depth=17
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 348 extra nodes, 0 pruned nodes, max_depth=23
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 336 extra nodes, 0 pruned nodes, max_depth=16
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 354 extra nodes, 0 pruned nodes, max_depth=23
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 378 extra nodes, 0 pruned nodes, max_depth=18
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 390 extra nodes, 0 pruned nodes, max_depth=19
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 394 extra nodes, 0 pruned nodes, max_depth=19
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 412 extra nodes, 0 pruned nodes, max_depth=20
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 414 extra nodes, 0 pruned nodes, max_depth=18
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 414 extra nodes, 0 pruned nodes, max_depth=25
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 452 extra nodes, 0 pruned nodes, max_depth=21
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 444 extra nodes, 0 pruned nodes, max_depth=27
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 444 extra nodes, 0 pruned nodes, max_depth=21
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 450 extra nodes, 0 pruned nodes, max_depth=22
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 470 extra nodes, 0 pruned nodes, max_depth=25
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 464 extra nodes, 0 pruned nodes, max_depth=23
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 464 extra nodes, 0 pruned nodes, max_depth=19
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 460 extra nodes, 0 pruned nodes, max_depth=21
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 444 extra nodes, 0 pruned nodes, max_depth=21
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 456 extra nodes, 0 pruned nodes, max_depth=24
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 448 extra nodes, 0 pruned nodes, max_depth=25
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 416 extra nodes, 0 pruned nodes, max_depth=26
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 406 extra nodes, 0 pruned nodes, max_depth=33
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 400 extra nodes, 0 pruned nodes, max_depth=29
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 368 extra nodes, 0 pruned nodes, max_depth=24
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 348 extra nodes, 0 pruned nodes, max_depth=24
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 320 extra nodes, 0 pruned nodes, max_depth=26
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 300 extra nodes, 0 pruned nodes, max_depth=27
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 282 extra nodes, 0 pruned nodes, max_depth=29
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 258 extra nodes, 0 pruned nodes, max_depth=24
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 240 extra nodes, 0 pruned nodes, max_depth=24
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 238 extra nodes, 0 pruned nodes, max_depth=32
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 182 extra nodes, 0 pruned nodes, max_depth=22
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 176 extra nodes, 0 pruned nodes, max_depth=25

Extracting simple rules from XGBoost

splitter = DefragModel.parseXGBtrees('xgbmodel.txt')  # parse the dumped XGBoost trees into an array of (feature index, threshold) splits
mdl = DefragModel(modeltype='classification', maxitr=100, qitr=0, tol=1e-6, restart=20, verbose=0)
mdl.fit(data_set, target_set, splitter, 10, fittype='FAB')  # 10 = maximum number of rules (Kmax)

# results
score, cover, coll = mdl.evaluate(data_set, target_set)
print()
print('<< defragTrees >>')
print('----- Evaluated Results -----')
print('Test Error = %f' % (score,))
print('Test Coverage = %f' % (cover,))
print('Overlap = %f' % (coll,))
print()
print('----- Found Rules -----')
print(mdl)
[Seed   0] TrainingError = 0.30, K = 3
[Seed   1] TrainingError = 0.30, K = 3
[Seed   2] TrainingError = 0.30, K = 2
[Seed   3] TrainingError = 0.30, K = 4
[Seed   4] TrainingError = 0.30, K = 2
[Seed   5] TrainingError = 0.30, K = 3
[Seed   6] TrainingError = 0.30, K = 3
[Seed   7] TrainingError = 0.30, K = 1
[Seed   8] TrainingError = 0.30, K = 3
[Seed   9] TrainingError = 0.30, K = 3
[Seed  10] TrainingError = 0.30, K = 2
[Seed  11] TrainingError = 0.30, K = 3
[Seed  12] TrainingError = 0.30, K = 3
[Seed  13] TrainingError = 0.30, K = 3
[Seed  14] TrainingError = 0.30, K = 4
[Seed  15] TrainingError = 0.30, K = 2
[Seed  16] TrainingError = 0.30, K = 2
[Seed  17] TrainingError = 0.30, K = 3
[Seed  18] TrainingError = 0.30, K = 3
[Seed  19] TrainingError = 0.30, K = 4
Optimal Model >> Seed   0, TrainingError = 0.30, K = 3

<< defragTrees >>
----- Evaluated Results -----
Test Error = 0.300000
Test Coverage = 1.000000
Overlap = 0.122000

----- Found Rules -----
[Rule  1]
y = 0 when
     x_1 < 0.389858
     x_3 >= -1.891694
     x_4 < 13.444655
     x_10 >= -15.387629
     x_12 >= -11.851440
     x_14 >= -16.041128
     x_18 >= -14.103243
     x_45 < 13.571693
     x_68 >= -11.976764
     x_89 < 16.081133

[Rule  2]
y = 0 when
     x_1 < 0.389858
     x_3 < 0.876493
     x_9 < 2.240515
     x_20 < 12.560732
     x_26 < 15.807751
     x_41 < 14.043766
     x_78 >= -15.805157
     x_81 < 14.394110

[Rule  3]
y = 0 when
     x_1 >= -0.301440

[Otherwise]
y = 0

Interpreting the results

Three rules were extracted for classifying this data. They can be read in the same way as the Random Forests rules, so only one example is given here.

  • Rule 3
    • y = 0 when x_1 >= -0.301440
      • When the value of the 1st feature (x_1) is greater than or equal to -0.301440, the class label is classified as 0

Summary

We built complex tree-based machine learning models and were able to extract rules of the form "if feature *** is ***, then the class label is classified as ***". How useful this is in practice remains to be seen, but we were able to get defragTrees running.
