はじめに
Scikit-learnの決定木の分岐情報は一般的にGraphVizを使用した可視化によって理解されます。
しかし、この方法だと人間が直感的に理解するには良いものの、
- データとして決定パスの情報を利用できない
-
GraphVizとの連携が少し煩雑なためサクッと可視化できない
といった問題をはらんでいます。
前者は、大量の決定木の分岐条件をデータとして扱いたい場合や、最終的な分岐条件をビジネス的な理由で全てテキスト化せねばならない場合などに人力では対処できずに苦しむ結果になりますし、
後者はGraphVizがなかなか入らない初心者や、新しいソフトウェアを入れづらい厄介な分析環境を使用せねばならない場合に苦しむことになります。
そこで、できればGraphVizを使用した可視化に頼らず、直接データとして分岐情報を扱う方法を知りたいと思っていました。
そこで今回はScikit-learnの決定木オブジェクト自体から分岐情報を抽出し、以下のように各ノードの最終的な分岐条件をデータとして格納する方法を模索します。
node | 条件1 | 条件2 |
---|---|---|
1 | $x\le 1$ | - |
2 | $1\lt x \le3$ | $4\lt x$ |
3 | $1\lt x \le3$ | $x \le 4$ |
4 | $3\lt x$ | - |
全体の実行ipynbはこちら。
先人たちの軌跡
-
Understanding the decision tree structure
- 公式のScikit-learnサイトにおいて決定木の構造をテキストで可視化する方法を紹介しています。本稿ではこちらを主に参考にして分岐条件の抽出を行います。
-
[Python]Graphviz不要の決定木可視化ライブラリdtreepltをつくった
- 直接試してはいませんが、GraphVizを使用せずにMatplotlibで類似の画像を描画する試みも存在します。テキストでの可視化は直感的ではないので、理解のためには画像での描画は必須ですね。
-
dtreeviz : Decision Tree Visualization
- こちらは論点がずれますが、GraphVizを使用しつつ、ノードにおける分岐条件をよりリッチにした描画を可能にしたものです。ただ、情報がリッチになりすぎる分可読性は下がるので使いどころが難しい気もします。
決定木オブジェクトからの分岐条件の抽出
それでは具体的な内容へと移ります。
上でも紹介しましたが、Scikit-learnの公式サイトを漁ってみると、"Understanding the decision tree structure"という解説サイトがあります。
こちらによると、決定木オブジェクトにおける分岐情報は決定木オブジェクトの上位階層tree_におけるいくつかの属性にノード数サイズのarray形式で保存されていることがわかります。
# estimator : 決定木オブジェクト
estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)
##### 一部削除 #####
# 決定木構造情報の抽出
# Using those arrays, we can parse the tree structure:
n_nodes = estimator.tree_.node_count # ノードの数(int)
children_left = estimator.tree_.children_left # 各ノードから左(True, 閾値以下)への分岐先ノード番号(list)
children_right = estimator.tree_.children_right # 各ノードから右(False, 閾値超)への分岐先ノード番号(list)
feature = estimator.tree_.feature # 各ノードの分岐に使用される変数の番号(list)
threshold = estimator.tree_.threshold # 分岐の閾値(list)
その他の重要な変数を含めて簡単にまとめると以下のようになります。
leaf(末端ノード)で該当するものが無い場合は-1または-2が格納されているようですね。
属性名 | 内容 | 例 |
---|---|---|
node_count | 総ノード数 | 11 |
feature | 変数の番号 | array([ 2, -2, 3, 2...]) |
threshold | 変数の閾値 | array([2.45,-2.,1.75,4.65,...]) |
children_left | 閾値以下の分岐先ノード番号 | array([ 1, -1, 3, 4,...]) |
children_right | 閾値超の分岐先ノード番号 | array([ 2, -1, 8, 7,...]) |
value | ノードの値 | array([[[50., 50., 50.]],[[50., 0., 0.]],[[ 0., 50., 50.]],...) |
impurity | ノードの不純度 | array([0.667,0.,0.5,0.168,...]) |
n_node_samples | ノードのサンプル数 | array([150, 50, 100, 54,...]) |
例は具体的にirisデータセットの決定木作成結果から抽出してみた結果で、GraphVizを使用して可視化した結果と比較するとわかりやすいかと思います。
無事に分岐条件に関連する情報を抽出できました。
決定木の構造をテキストで可視化
以上で抽出した情報を使用し、上記公式サイトを参考にして決定木の分岐情報をテキスト化してみます。
import numpy as np
def visualizeTreeStructure(decision_tree,feature_names=None):
'''return text decision tree structure information
ref: https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
'''
# retrieve decision tree structure information
txt = ''
n_nodes = decision_tree.tree_.node_count
children_left = decision_tree.tree_.children_left
children_right = decision_tree.tree_.children_right
feature = decision_tree.tree_.feature
threshold = decision_tree.tree_.threshold
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
# define each node is leaf or not
stack = [(0, -1)] # seed is the root node id and its parent depth
while len(stack) > 0:
node_id, parent_depth = stack.pop()
node_depth[node_id] = parent_depth + 1
# If we have a test node
if (children_left[node_id] != children_right[node_id]):
stack.append((children_left[node_id], parent_depth + 1))
stack.append((children_right[node_id], parent_depth + 1))
else:
is_leaves[node_id] = True
txt += "The binary tree structure has {} nodes and has the following tree structure:".format(n_nodes)
for i in range(n_nodes):
if is_leaves[i]:
txt += "\n{}node={} leaf node.".format(node_depth[i] * "\t", i)
else:
if feature_names is None:
txt+= "\n{}node={} test node: go to node {} if X[{:.0f}] <= {} else to node {}.".format(node_depth[i] * "\t",
i,
children_left[i],
feature[i],
threshold[i],
children_right[i],
)
else:
txt+= "\n{}node={} test node: go to node {} if {} <= {} else to node {}.".format(node_depth[i] * "\t",
i,
children_left[i],
feature_names[feature[i]],
threshold[i],
children_right[i],
)
return txt
>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> iris = load_iris()
>>> dtc = DecisionTreeClassifier(max_depth=5,min_samples_leaf=10)
>>> dtc.fit(iris['data'],iris['target'])
>>> print(visualizeTreeStructure(dtc,feature_names=iris['feature_names']))
# 出力結果
The binary tree structure has 11 nodes and has the following tree structure:
node=0 test node: go to node 1 if petal length (cm) <= 2.449999988079071 else to node 2.
node=1 leaf node.
node=2 test node: go to node 3 if petal width (cm) <= 1.75 else to node 8.
node=3 test node: go to node 4 if petal length (cm) <= 4.6499998569488525 else to node 7.
node=4 test node: go to node 5 if petal length (cm) <= 4.450000047683716 else to node 6.
node=5 leaf node.
node=6 leaf node.
node=7 leaf node.
node=8 test node: go to node 9 if sepal length (cm) <= 6.25 else to node 10.
node=9 leaf node.
node=10 leaf node.
GraphVizの出力結果に比べれば決して見やすくはありませんが、無事に同じ構造をテキストで再現することができました。
ノードごとに決定条件のサマリーをデータ化
これまでに決定木オブジェクトから必要な情報は抽出できたものの、データの格納方法が分岐単位になっており、少し使い勝手に欠けます。
例えば、上記例のnode 6の決定条件は、上から辿っていくとpetal length (cm)の条件分岐が3回、petal width (cm)の条件分岐が1回出てきていますが、まとめると最初のpetal length (cm)の条件($\gt 2.45$)は閾値とならず、以下の閾値条件のみで決まっていることがわかります。
4.45 \lt \text{petal length (cm)} \le 4.65\\
\text{petal width (cm)} \le 1.75
そこで、各ノードごとに最終的な決定条件をまとめたデータフレームを作成します。
import numpy as np
import pandas as pd
def extractDecisionInfo(decision_tree,X_train,feature_names=None,only_leaves=False):
'''return dataframe with node info
'''
# extract info from decision_tree
n_nodes = decision_tree.tree_.node_count
children_left = decision_tree.tree_.children_left
children_right = decision_tree.tree_.children_right
feature = decision_tree.tree_.feature
threshold = decision_tree.tree_.threshold
impurity = decision_tree.tree_.impurity
value = decision_tree.tree_.value
n_node_samples = decision_tree.tree_.n_node_samples
# cast X_train as dataframe
df = pd.DataFrame(X_train)
if feature_names is not None:
df.columns = feature_names
# indexes with unique nodes
idx_list = df.assign(
leaf_id = lambda df: decision_tree.apply(df)
)[['leaf_id']].drop_duplicates().index
# test data for unique nodes
X_test = df.loc[idx_list,].to_numpy()
# decision path only for leaves
dp = decision_tree.decision_path(X_test)
# final leaves for each data
leave_id = decision_tree.apply(X_test)
# values for each data
leave_predict = decision_tree.predict(X_test)
# dictionary for leave_id and leave_predict
dict_leaves = {k:v for k,v in zip(leave_id,leave_predict)}
# create decision path information for all nodes
dp_idxlist = [[ini, fin] for ini,fin in zip(dp.indptr[:-1],dp.indptr[1:])]
dict_decisionpath = {}
for idxs in dp_idxlist:
dpindices = dp.indices[idxs[0]:idxs[1]]
for i,node in enumerate(dpindices):
if node not in dict_decisionpath.keys():
dict_decisionpath[node] = dpindices[:i+1]
# initialize number of columns and output dataframe
n_cols = df.shape[-1]
df_thr_all = pd.DataFrame()
# predict for samples
for node, node_index in dict_decisionpath.items():
l_thresh_max = np.ones(n_cols) * np.nan
l_thresh_min = np.ones(n_cols) * np.nan
# decision path info for each node
for i,node_id in enumerate(node_index):
if node == node_id:
continue
if children_left[node_id] == node_index[i+1]: #(X_test[sample_id, feature[node_id]] <= threshold[node_id]):
l_thresh_max[feature[node_id]] = threshold[node_id]
else:
l_thresh_min[feature[node_id]] = threshold[node_id]
# append info to df_thr_all
df_thr_all = df_thr_all.append(
[[(thr_min,thr_max) for thr_max,thr_min in zip(l_thresh_max,l_thresh_min)]
+ [
node,
np.nan if node not in dict_leaves.keys() else dict_leaves[node],
value[node],
impurity[node],
n_node_samples[node]
]
]
)
# rename columns and set index
if feature_names is not None:
df_thr_all.columns = feature_names + ['node','predicted_value','value','impurity','n_node_samples']
else:
df_thr_all.columns = ['X_{}'.format(i) for i in range(n_cols)] + ['node','predicted_value','value','impurity','n_node_samples']
df_thr_all = df_thr_all.set_index('node')
if only_leaves:
df_thr_all = df_thr_all[~df_thr_all['predicted_value'].isnull()]
return df_thr_all.sort_index()
>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> iris = load_iris()
>>> dtc = DecisionTreeClassifier(max_depth=5,min_samples_leaf=10)
>>> dtc.fit(iris['data'],iris['target'])
>>> print(extractDecisionInfo(dtc,iris['data'],feature_names=iris['feature_names']))
# 出力結果
sepal length (cm) sepal width (cm) \
node
0 (nan, nan) (nan, nan)
1 (nan, nan) (nan, nan)
2 (nan, nan) (nan, nan)
3 (nan, nan) (nan, nan)
4 (nan, nan) (nan, nan)
5 (nan, nan) (nan, nan)
6 (nan, nan) (nan, nan)
7 (nan, nan) (nan, nan)
8 (nan, nan) (nan, nan)
9 (nan, 6.25) (nan, nan)
10 (6.25, nan) (nan, nan)
petal length (cm) petal width (cm) \
node
0 (nan, nan) (nan, nan)
1 (nan, nan) (nan, 0.800000011920929)
2 (nan, nan) (0.800000011920929, nan)
3 (nan, nan) (0.800000011920929, 1.75)
4 (nan, 4.6499998569488525) (0.800000011920929, 1.75)
5 (nan, 4.450000047683716) (0.800000011920929, 1.75)
6 (4.450000047683716, 4.6499998569488525) (0.800000011920929, 1.75)
7 (4.6499998569488525, nan) (0.800000011920929, 1.75)
8 (nan, nan) (1.75, nan)
9 (nan, nan) (1.75, nan)
10 (nan, nan) (1.75, nan)
predicted_value value impurity n_node_samples
node
0 NaN [[50.0, 50.0, 50.0]] 0.666667 150
1 0.0 [[50.0, 0.0, 0.0]] 0.000000 50
2 NaN [[0.0, 50.0, 50.0]] 0.500000 100
3 NaN [[0.0, 49.0, 5.0]] 0.168038 54
4 NaN [[0.0, 39.0, 1.0]] 0.048750 40
5 1.0 [[0.0, 29.0, 0.0]] 0.000000 29
6 1.0 [[0.0, 10.0, 1.0]] 0.165289 11
7 1.0 [[0.0, 10.0, 4.0]] 0.408163 14
8 NaN [[0.0, 1.0, 45.0]] 0.042533 46
9 2.0 [[0.0, 1.0, 10.0]] 0.165289 11
10 2.0 [[0.0, 0.0, 35.0]] 0.000000 35
以上で各変数名のカラムに(最小値, 最大値)の形のSet型で分岐の決定条件が格納されました。
具体的にnode 6の最終的な決定条件と見比べてみると、
4.45 \lt \text{petal length (cm)} \le 4.65\\
\text{petal width (cm)} \le 1.75
であったのに対し、出力されたデータフレームの該当箇所を抽出すると、
node | petal length (cm) | petal width (cm) |
---|---|---|
6 | (4.45, 4.65) | (nan, 1.75) |
となっており、無事に閾値条件が表現されています。
以上で決定木のノードごとに分岐条件サマリーをデータとして格納することができました。
全体を実行したipynbはGitHubへアップロードしてあるのでよければどうぞ。