Introduction
Background
Machine learning models have grown more complex year after year. In business settings in particular, when you have to explain something based on a machine learning model, it is hard to interpret a model that has become a black box to the human eye. This has created a demand for extracting simple, immediately understandable rules from such black-box models.
Contents
On this page, we build the following decision-tree-based machine learning models:
- Random Forests
- XGBoost
We then use defragTrees to extract simplified rules from them.
The procedure for LightGBM is almost the same, so please refer to the original page.
Academic background
This page is based on the paper Making Tree Ensembles Interpretable: A Bayesian Model Selection Approach. The source code accompanying the paper is published as defragTrees.
If you would rather read slides in Japanese, see 「アンサンブル木モデル解釈のためのモデル簡略化法」 (a model simplification method for interpreting tree-ensemble models).
Installation
defragTrees does not support installation via pip or git+URL. Instead, download defragTrees.py from here and place it in the same folder as your source file; you can then use it with `from defragTrees import DefragModel`.
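If you prefer to script the download, here is a minimal sketch; the raw-file URL is an assumption (defragTrees.py at the root of the sato9hara/defragTrees repository, master branch):
import urllib.request

# Fetch defragTrees.py next to your own script
# (URL assumes the file lives at the repository root on the master branch)
url = "https://raw.githubusercontent.com/sato9hara/defragTrees/master/defragTrees.py"
urllib.request.urlretrieve(url, "defragTrees.py")

from defragTrees import DefragModel  # now importable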
import numpy as np
from sklearn.datasets import make_classification
from imblearn.ensemble import BalancedRandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.metrics import accuracy_score, cohen_kappa_score, balanced_accuracy_score, make_scorer, f1_score, recall_score
from sklearn.ensemble import RandomForestClassifier
import lightgbm as lgb
import xgboost as xgb
from defragTrees import DefragModel
from joblib import dump, load  # sklearn.externals.joblib was removed from recent scikit-learn
Experiments
Creating test data
Parameter values:
- n_samples: number of samples
  - e.g., the number of respondents in a questionnaire survey
- n_features: number of features
  - the features used to discriminate the target classes
  - e.g., question-answer pairs in a questionnaire survey
- n_informative: number of features actually related to the target classes
  - the more of these there are, the easier prediction becomes
- weights: the proportion of samples assigned to each class
  - [0.7, 0.3] below makes the two classes imbalanced (roughly 70% vs. 30%)
- n_classes: number of target classes
  - set to 2 for a binary classification problem
- random_state: the random seed
data = make_classification(n_samples=1000,     # number of samples to generate
                           n_features=110,
                           n_informative=100,
                           weights=[0.7, 0.3], # class proportions (imbalanced)
                           n_classes=2,
                           random_state=43)
data_set = data[0]    # feature matrix
target_set = data[1]  # class labels
# Alternatively, try the iris dataset:
#from sklearn.datasets import load_iris
#iris = load_iris()
#data_set = iris.data
#target_set = iris.target
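Because weights=[0.7, 0.3] makes the classes imbalanced, it is worth confirming the actual label distribution before modeling; the counts will be close to, though not exactly, 700 and 300:
print(np.bincount(target_set))  # number of samples per class label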
Random Forests
Building the model
model = BalancedRandomForestClassifier(random_state=43,
                                       n_jobs=1,
                                       n_estimators=500,
                                       max_features="log2",
                                       class_weight='balanced',
                                       sampling_strategy='all',
                                       max_depth=None,
                                       oob_score=False)
scoring = {'accuracy': make_scorer(accuracy_score),
           'kappa': make_scorer(cohen_kappa_score),
           'balanced_accuracy': make_scorer(balanced_accuracy_score)}
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_validate(model, data_set, target_set, n_jobs=1,
                        cv=skf, return_train_score=True, scoring=scoring)
print(scores["test_accuracy"])
print(scores["test_kappa"])
print(scores["test_balanced_accuracy"])
model.fit(data_set, target_set)
[0.735 0.765 0.795 0.765 0.815]
[0.46356275 0.49570815 0.56008584 0.5 0.59606987]
[0.77261905 0.775 0.81071429 0.7797619 0.825 ]
BalancedRandomForestClassifier(bootstrap=True, class_weight='balanced',
criterion='gini', max_depth=None, max_features='log2',
max_leaf_nodes=None, min_impurity_decrease=0.0,
min_samples_leaf=2, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=500, n_jobs=1,
oob_score=False, random_state=43, replacement=False,
sampling_strategy='all', verbose=0, warm_start=False)
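At this point the forest is still a black box: on its own, the closest thing to an explanation it offers is aggregate feature importances, not rules. As an illustrative contrast with the rule extraction below:
# The fitted ensemble only exposes per-feature importance scores
top5 = np.argsort(model.feature_importances_)[::-1][:5]
print('Top-5 most important features:', top5)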
Extracting simple rules from Random Forests
splitter = DefragModel.parseSLtrees(model) # parse sklearn tree ensembles into an array of (feature index, threshold) splits
mdl = DefragModel(modeltype='classification', maxitr=100, qitr=0, tol=1e-6, restart=20, verbose=0)
mdl.fit(data_set, target_set, splitter, 10, fittype='FAB') # 10 = Kmax, the upper bound on the number of rules
# results
score, cover, coll = mdl.evaluate(data_set, target_set)
print()
print('<< defragTrees >>')
print('----- Evaluated Results -----')
print('Test Error = %f' % (score,))
print('Test Coverage = %f' % (cover,))
print('Overlap = %f' % (coll,))
print()
print('----- Found Rules -----')
print(mdl)
[Seed 0] TrainingError = 0.30, K = 2
[Seed 1] TrainingError = 0.30, K = 3
[Seed 2] TrainingError = 0.30, K = 2
[Seed 3] TrainingError = 0.30, K = 3
[Seed 4] TrainingError = 0.30, K = 2
[Seed 5] TrainingError = 0.30, K = 2
[Seed 6] TrainingError = 0.30, K = 2
[Seed 7] TrainingError = 0.30, K = 2
[Seed 8] TrainingError = 0.30, K = 3
[Seed 9] TrainingError = 0.30, K = 2
[Seed 10] TrainingError = 0.30, K = 3
[Seed 11] TrainingError = 0.30, K = 2
[Seed 12] TrainingError = 0.30, K = 1
[Seed 13] TrainingError = 0.30, K = 2
[Seed 14] TrainingError = 0.30, K = 1
[Seed 15] TrainingError = 0.30, K = 2
[Seed 16] TrainingError = 0.30, K = 2
[Seed 17] TrainingError = 0.30, K = 2
[Seed 18] TrainingError = 0.30, K = 2
[Seed 19] TrainingError = 0.30, K = 3
Optimal Model >> Seed 0, TrainingError = 0.30, K = 2
<< defragTrees >>
----- Evaluated Results -----
Test Error = 0.300000
Test Coverage = 1.000000
Overlap = 0.942000
----- Found Rules -----
[Rule 1]
y = 0 when
x_31 < 14.781519
[Rule 2]
y = 0 when
x_13 >= -13.205372
x_28 < 13.061854
x_96 < 12.163099
[Otherwise]
y = 0
Interpreting the results
Two rules were extracted for classifying this data.
- Rule 1
  - y = 0 when x_31 < 14.781519
  - Interpretation: when the value of feature 31 (x_31) is less than 14.781519, the class label (y) is predicted to be 0
- Rule 2
  - y = 0 when x_13 >= -13.205372, x_28 < 13.061854, and x_96 < 12.163099
  - Interpretation: when feature 13 (x_13) is at least -13.205372, feature 28 (x_28) is less than 13.061854, and feature 96 (x_96) is less than 12.163099, the class label (y) is predicted to be 0
- Otherwise, the class label is 0
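To check how faithfully the simplified rules mimic the original forest, you can compare their predictions directly. A minimal sketch, assuming DefragModel exposes a predict method as in the repository's examples:
# Fraction of samples where the rule model agrees with the forest
# (mdl.predict is assumed to behave as in the defragTrees examples)
rule_pred = mdl.predict(data_set)
rf_pred = model.predict(data_set)
print('Fidelity to the random forest: %f' % np.mean(rule_pred == rf_pred))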
XGBoost
Building the model
model = xgb.XGBClassifier(max_depth=50,
                          learning_rate=0.16,
                          min_child_weight=1,
                          n_estimators=200)
scoring = {'accuracy': make_scorer(accuracy_score),
           'kappa': make_scorer(cohen_kappa_score),
           'balanced_accuracy': make_scorer(balanced_accuracy_score)}
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_validate(model, data_set, target_set, n_jobs=1,
                        cv=skf, return_train_score=True, scoring=scoring)
print(scores["test_accuracy"])
print(scores["test_kappa"])
print(scores["test_balanced_accuracy"])
num_round = 50  # in the native API, the number of boosting rounds is set here, not via n_estimators
dtrain = xgb.DMatrix(data_set, label=target_set)
param = {'max_depth': 50, 'learning_rate': 0.16, 'min_child_weight': 1}
bst = xgb.train(param, dtrain, num_round)
# dump the trained XGBoost model as text for defragTrees to parse
bst.dump_model('xgbmodel.txt')
[0.8 0.825 0.88 0.8 0.86 ]
[0.46236559 0.54188482 0.67741935 0.44444444 0.62365591]
[0.7047619 0.74642857 0.8 0.69047619 0.77619048]
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 194 extra nodes, 0 pruned nodes, max_depth=14
[19:39:56] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 212 extra nodes, 0 pruned nodes, max_depth=14
... (one pruning line per boosting round; 47 similar lines omitted)
[19:39:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 176 extra nodes, 0 pruned nodes, max_depth=25
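Before parsing, you can peek at the text dump written by dump_model; each tree appears as indented split and leaf lines in XGBoost's dump format:
# Show the first few lines of the dumped model
with open('xgbmodel.txt') as f:
    for line in list(f)[:8]:
        print(line.rstrip())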
Extracting simple rules from XGBoost
splitter = DefragModel.parseXGBtrees('xgbmodel.txt') # parse the dumped XGBoost trees into an array of (feature index, threshold) splits
mdl = DefragModel(modeltype='classification', maxitr=100, qitr=0, tol=1e-6, restart=20, verbose=0)
mdl.fit(data_set, target_set, splitter, 10, fittype='FAB') # 10 = Kmax, the upper bound on the number of rules
# results
score, cover, coll = mdl.evaluate(data_set, target_set)
print()
print('<< defragTrees >>')
print('----- Evaluated Results -----')
print('Test Error = %f' % (score,))
print('Test Coverage = %f' % (cover,))
print('Overlap = %f' % (coll,))
print()
print('----- Found Rules -----')
print(mdl)
[Seed 0] TrainingError = 0.30, K = 3
[Seed 1] TrainingError = 0.30, K = 3
[Seed 2] TrainingError = 0.30, K = 2
[Seed 3] TrainingError = 0.30, K = 4
[Seed 4] TrainingError = 0.30, K = 2
[Seed 5] TrainingError = 0.30, K = 3
[Seed 6] TrainingError = 0.30, K = 3
[Seed 7] TrainingError = 0.30, K = 1
[Seed 8] TrainingError = 0.30, K = 3
[Seed 9] TrainingError = 0.30, K = 3
[Seed 10] TrainingError = 0.30, K = 2
[Seed 11] TrainingError = 0.30, K = 3
[Seed 12] TrainingError = 0.30, K = 3
[Seed 13] TrainingError = 0.30, K = 3
[Seed 14] TrainingError = 0.30, K = 4
[Seed 15] TrainingError = 0.30, K = 2
[Seed 16] TrainingError = 0.30, K = 2
[Seed 17] TrainingError = 0.30, K = 3
[Seed 18] TrainingError = 0.30, K = 3
[Seed 19] TrainingError = 0.30, K = 4
Optimal Model >> Seed 0, TrainingError = 0.30, K = 3
<< defragTrees >>
----- Evaluated Results -----
Test Error = 0.300000
Test Coverage = 1.000000
Overlap = 0.122000
----- Found Rules -----
[Rule 1]
y = 0 when
x_1 < 0.389858
x_3 >= -1.891694
x_4 < 13.444655
x_10 >= -15.387629
x_12 >= -11.851440
x_14 >= -16.041128
x_18 >= -14.103243
x_45 < 13.571693
x_68 >= -11.976764
x_89 < 16.081133
[Rule 2]
y = 0 when
x_1 < 0.389858
x_3 < 0.876493
x_9 < 2.240515
x_20 < 12.560732
x_26 < 15.807751
x_41 < 14.043766
x_78 >= -15.805157
x_81 < 14.394110
[Rule 3]
y = 0 when
x_1 >= -0.301440
[Otherwise]
y = 0
Interpreting the results
Three rules were extracted for classifying this data. The rules are read the same way as for Random Forests, so we give just one example.
- Rule 3
  - y = 0 when x_1 >= -0.301440
  - Interpretation: when the value of feature 1 (x_1) is at least -0.301440, the class label (y) is predicted to be 0
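The dump and load imported at the top can be used to persist a fitted rule model for later reuse. A minimal sketch, assuming the model pickles cleanly (the file name is arbitrary):
dump(mdl, 'defrag_rules.joblib')    # save the fitted rule model
mdl_restored = load('defrag_rules.joblib')
print(mdl_restored)                 # prints the same rule list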
Summary
We built complex decision-tree-based machine learning models and extracted rules of the form "if feature *** , then the class label is classified as ***". How useful this is in practice remains to be seen, but we were able to run defragTrees end to end.