LoginSignup
18
20

More than 3 years have passed since last update.

Scikit-learnの決定木からノードの分岐条件を抽出、データとして格納する

Last updated at Posted at 2019-07-13

はじめに

Scikit-learnの決定木の分岐情報は一般的にGraphVizを使用した可視化によって理解されます。

しかし、この方法だと人間が直感的に理解するには良いものの、
- データとして決定パスの情報を利用できない
- GraphVizとの連携が少し煩雑なためサクッと可視化できない
といった問題をはらんでいます。

前者は、大量の決定木の分岐条件をデータとして扱いたい場合や、最終的な分岐条件をビジネス的な理由で全てテキスト化せねばならない場合などに人力では対処できずに苦しむ結果になりますし、
後者はGraphVizがなかなか入らない初心者や、新しいソフトウェアを入れづらい厄介な分析環境を使用せねばならない場合に苦しむことになります。

そこで、できればGraphVizを使用した可視化に頼らず、直接データとして分岐情報を扱う方法を知りたいと思っていました。

そこで今回はScikit-learnの決定木オブジェクト自体から分岐情報を抽出し、以下のように各ノードの最終的な分岐条件をデータとして格納する方法を模索します。

node 条件1 条件2
1 $x\le 1$ -
2 $1\lt x \le3$ $4\lt x$
3 $1\lt x \le3$ $x \le 4$
4 $3\lt x$ -

全体の実行ipynbはこちら

先人たちの軌跡

  • Understanding the decision tree structure

    • 公式のScikit-learnサイトにおいて決定木の構造をテキストで可視化する方法を紹介しています。本稿ではこちらを主に参考にして分岐条件の抽出を行います。
  • [Python]Graphviz不要の決定木可視化ライブラリdtreepltをつくった

    • 直接試してはいませんが、GraphVizを使用せずにMatplotlibで類似の画像を描画する試みも存在します。テキストでの可視化は直感的ではないので、理解のためには画像での描画は必須ですね。
  • dtreeviz : Decision Tree Visualization

    • こちらは論点がずれますが、GraphVizを使用しつつ、ノードにおける分岐条件をよりリッチにした描画を可能にしたものです。ただ、情報がリッチになりすぎる分可読性は下がるので使いどころが難しい気もします。

決定木オブジェクトからの分岐条件の抽出

それでは具体的な内容へと移ります。

上でも紹介しましたが、Scikit-learnの公式サイトを漁ってみると、"Understanding the decision tree structure"という解説サイトがあります。

こちらによると、決定木オブジェクトにおける分岐情報は決定木オブジェクトの上位階層tree_におけるいくつかの属性にノード数サイズのarray形式で保存されていることがわかります。

plot_unveil_tree_structure.pyより抜粋、一部加筆
# estimator : 決定木オブジェクト
estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)
#####   一部削除   #####

# 決定木構造情報の抽出
# Using those arrays, we can parse the tree structure:
n_nodes = estimator.tree_.node_count # ノードの数(int)
children_left = estimator.tree_.children_left # 各ノードから左(True, 閾値以下)への分岐先ノード番号(list)
children_right = estimator.tree_.children_right # 各ノードから右(False, 閾値超)への分岐先ノード番号(list)
feature = estimator.tree_.feature # 各ノードの分岐に使用される変数の番号(list)
threshold = estimator.tree_.threshold # 分岐の閾値(list)

その他の重要な変数を含めて簡単にまとめると以下のようになります。
leaf(末端ノード)で該当するものが無い場合は-1または-2が格納されているようですね。

属性名 内容
node_count 総ノード数 11
feature 変数の番号 array([ 2, -2, 3, 2...])
threshold 変数の閾値 array([2.45,-2.,1.75,4.65,...])
children_left 閾値以下の分岐先ノード番号 array([ 1, -1, 3, 4,...])
children_right 閾値超の分岐先ノード番号 array([ 2, -1, 8, 7,...])
value ノードの値 array([[[50., 50., 50.]],[[50., 0., 0.]],[[ 0., 50., 50.]],...)
impurity ノードの不純度 array([0.667,0.,0.5,0.168,...])
n_node_samples ノードのサンプル数 array([150, 50, 100, 54,...])

例は具体的にirisデータセットの決定木作成結果から抽出してみた結果で、GraphVizを使用して可視化した結果と比較するとわかりやすいかと思います。
図解_GraphViz.jpg

無事に分岐条件に関連する情報を抽出できました。

決定木の構造をテキストで可視化

以上で抽出した情報を使用し、上記公式サイトを参考にして決定木の分岐情報をテキスト化してみます。

visualizeTreeStructure
import numpy as np

def visualizeTreeStructure(decision_tree,feature_names=None):
    '''return text decision tree structure information
       ref: https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
    '''

    # retrieve decision tree structure information
    txt = ''
    n_nodes = decision_tree.tree_.node_count
    children_left = decision_tree.tree_.children_left
    children_right = decision_tree.tree_.children_right
    feature = decision_tree.tree_.feature
    threshold = decision_tree.tree_.threshold
    node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)

    # define each node is leaf or not
    stack = [(0, -1)]  # seed is the root node id and its parent depth
    while len(stack) > 0:
        node_id, parent_depth = stack.pop()
        node_depth[node_id] = parent_depth + 1

        # If we have a test node
        if (children_left[node_id] != children_right[node_id]):
            stack.append((children_left[node_id], parent_depth + 1))
            stack.append((children_right[node_id], parent_depth + 1))
        else:
            is_leaves[node_id] = True

    txt += "The binary tree structure has {} nodes and has the following tree structure:".format(n_nodes) 
    for i in range(n_nodes):
        if is_leaves[i]:
            txt += "\n{}node={} leaf node.".format(node_depth[i] * "\t", i)
        else:
            if feature_names is None:
                txt+= "\n{}node={} test node: go to node {} if X[{:.0f}] <= {} else to node {}.".format(node_depth[i] * "\t",
                         i,
                         children_left[i],
                         feature[i],
                         threshold[i],
                         children_right[i],
                         )
            else:
                txt+= "\n{}node={} test node: go to node {} if {} <= {} else to node {}.".format(node_depth[i] * "\t",
                         i,
                         children_left[i],
                         feature_names[feature[i]],
                         threshold[i],
                         children_right[i],
                         )
    return txt
出力結果
>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> iris = load_iris()
>>> dtc = DecisionTreeClassifier(max_depth=5,min_samples_leaf=10)
>>> dtc.fit(iris['data'],iris['target'])
>>> print(visualizeTreeStructure(dtc,feature_names=iris['feature_names']))
# 出力結果
The binary tree structure has 11 nodes and has the following tree structure:
node=0 test node: go to node 1 if petal length (cm) <= 2.449999988079071 else to node 2.
    node=1 leaf node.
    node=2 test node: go to node 3 if petal width (cm) <= 1.75 else to node 8.
        node=3 test node: go to node 4 if petal length (cm) <= 4.6499998569488525 else to node 7.
            node=4 test node: go to node 5 if petal length (cm) <= 4.450000047683716 else to node 6.
                node=5 leaf node.
                node=6 leaf node.
            node=7 leaf node.
        node=8 test node: go to node 9 if sepal length (cm) <= 6.25 else to node 10.
            node=9 leaf node.
            node=10 leaf node.

GraphVizの出力結果に比べれば決して見やすくはありませんが、無事に同じ構造をテキストで再現することができました。

ノードごとに決定条件のサマリーをデータ化

これまでに決定木オブジェクトから必要な情報は抽出できたものの、データの格納方法が分岐単位になっており、少し使い勝手に欠けます。

例えば、上記例のnode 6の決定条件は、上から辿っていくとpetal length (cm)の条件分岐が3回、petal width (cm)の条件分岐が1回出てきていますが、まとめると最初のpetal length (cm)の条件($\gt 2.45$)は閾値とならず、以下の閾値条件のみで決まっていることがわかります。

4.45 \lt \text{petal length (cm)} \le 4.65\\
\text{petal width (cm)} \le 1.75

そこで、各ノードごとに最終的な決定条件をまとめたデータフレームを作成します。

extractDecisionInfo
import numpy as np
import pandas as pd

def extractDecisionInfo(decision_tree,X_train,feature_names=None,only_leaves=False):
    '''return dataframe with node info
    '''

    # extract info from decision_tree
    n_nodes = decision_tree.tree_.node_count
    children_left = decision_tree.tree_.children_left
    children_right = decision_tree.tree_.children_right
    feature = decision_tree.tree_.feature
    threshold = decision_tree.tree_.threshold
    impurity = decision_tree.tree_.impurity
    value = decision_tree.tree_.value
    n_node_samples = decision_tree.tree_.n_node_samples

    # cast X_train as dataframe
    df = pd.DataFrame(X_train)
    if feature_names is not None:
        df.columns = feature_names

    # indexes with unique nodes
    idx_list = df.assign(
        leaf_id = lambda df: decision_tree.apply(df)
    )[['leaf_id']].drop_duplicates().index

    # test data for unique nodes
    X_test = df.loc[idx_list,].to_numpy()
    # decision path only for leaves
    dp = decision_tree.decision_path(X_test)
    # final leaves for each data
    leave_id = decision_tree.apply(X_test)
    # values for each data
    leave_predict = decision_tree.predict(X_test)
    # dictionary for leave_id and leave_predict
    dict_leaves = {k:v for k,v in zip(leave_id,leave_predict)}

    # create decision path information for all nodes
    dp_idxlist = [[ini, fin] for ini,fin in zip(dp.indptr[:-1],dp.indptr[1:])]
    dict_decisionpath = {}
    for idxs in dp_idxlist:
        dpindices = dp.indices[idxs[0]:idxs[1]]
        for i,node in enumerate(dpindices):
            if node not in dict_decisionpath.keys():
                dict_decisionpath[node] = dpindices[:i+1]

    # initialize number of columns and output dataframe
    n_cols = df.shape[-1]
    df_thr_all = pd.DataFrame()

    # predict for samples

    for node, node_index in dict_decisionpath.items():
        l_thresh_max = np.ones(n_cols) * np.nan
        l_thresh_min = np.ones(n_cols) * np.nan

        # decision path info for each node
        for i,node_id in enumerate(node_index):
            if node == node_id:
                continue

            if children_left[node_id] == node_index[i+1]: #(X_test[sample_id, feature[node_id]] <= threshold[node_id]):
                l_thresh_max[feature[node_id]] = threshold[node_id]
            else:
                l_thresh_min[feature[node_id]] = threshold[node_id]

        # append info to df_thr_all
        df_thr_all = df_thr_all.append(
            [[(thr_min,thr_max) for thr_max,thr_min in zip(l_thresh_max,l_thresh_min)]
             + [
                 node,
                 np.nan if node not in dict_leaves.keys() else dict_leaves[node],
                 value[node],
                 impurity[node],
                 n_node_samples[node]
               ]
            ]
        )

    # rename columns and set index
    if feature_names is not None:
        df_thr_all.columns = feature_names + ['node','predicted_value','value','impurity','n_node_samples']
    else:
        df_thr_all.columns = ['X_{}'.format(i) for i in range(n_cols)] + ['node','predicted_value','value','impurity','n_node_samples']
    df_thr_all = df_thr_all.set_index('node')

    if only_leaves:
        df_thr_all = df_thr_all[~df_thr_all['predicted_value'].isnull()]

    return df_thr_all.sort_index()
出力結果
>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> iris = load_iris()
>>> dtc = DecisionTreeClassifier(max_depth=5,min_samples_leaf=10)
>>> dtc.fit(iris['data'],iris['target'])
>>> print(extractDecisionInfo(dtc,iris['data'],feature_names=iris['feature_names']))
# 出力結果
     sepal length (cm) sepal width (cm)  \
node                                      
0           (nan, nan)       (nan, nan)   
1           (nan, nan)       (nan, nan)   
2           (nan, nan)       (nan, nan)   
3           (nan, nan)       (nan, nan)   
4           (nan, nan)       (nan, nan)   
5           (nan, nan)       (nan, nan)   
6           (nan, nan)       (nan, nan)   
7           (nan, nan)       (nan, nan)   
8           (nan, nan)       (nan, nan)   
9          (nan, 6.25)       (nan, nan)   
10         (6.25, nan)       (nan, nan)   

                            petal length (cm)           petal width (cm)  \
node                                                                       
0                                  (nan, nan)                 (nan, nan)   
1                                  (nan, nan)   (nan, 0.800000011920929)   
2                                  (nan, nan)   (0.800000011920929, nan)   
3                                  (nan, nan)  (0.800000011920929, 1.75)   
4                   (nan, 4.6499998569488525)  (0.800000011920929, 1.75)   
5                    (nan, 4.450000047683716)  (0.800000011920929, 1.75)   
6     (4.450000047683716, 4.6499998569488525)  (0.800000011920929, 1.75)   
7                   (4.6499998569488525, nan)  (0.800000011920929, 1.75)   
8                                  (nan, nan)                (1.75, nan)   
9                                  (nan, nan)                (1.75, nan)   
10                                 (nan, nan)                (1.75, nan)   

      predicted_value                 value  impurity  n_node_samples  
node                                                                   
0                 NaN  [[50.0, 50.0, 50.0]]  0.666667             150  
1                 0.0    [[50.0, 0.0, 0.0]]  0.000000              50  
2                 NaN   [[0.0, 50.0, 50.0]]  0.500000             100  
3                 NaN    [[0.0, 49.0, 5.0]]  0.168038              54  
4                 NaN    [[0.0, 39.0, 1.0]]  0.048750              40  
5                 1.0    [[0.0, 29.0, 0.0]]  0.000000              29  
6                 1.0    [[0.0, 10.0, 1.0]]  0.165289              11  
7                 1.0    [[0.0, 10.0, 4.0]]  0.408163              14  
8                 NaN    [[0.0, 1.0, 45.0]]  0.042533              46  
9                 2.0    [[0.0, 1.0, 10.0]]  0.165289              11  
10                2.0    [[0.0, 0.0, 35.0]]  0.000000              35  

以上で各変数名のカラムに(最小値, 最大値)の形のSet型で分岐の決定条件が格納されました。

具体的にnode 6の最終的な決定条件と見比べてみると、

4.45 \lt \text{petal length (cm)} \le 4.65\\
\text{petal width (cm)} \le 1.75

であったのに対し、出力されたデータフレームの該当箇所を抽出すると、

node petal length (cm) petal width (cm)
6 (4.45, 4.65) (nan, 1.75)

となっており、無事に閾値条件が表現されています。

以上で決定木のノードごとに分岐条件サマリーをデータとして格納することができました。

全体を実行したipynbはGitHubへアップロードしてあるのでよければどうぞ。

18
20
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
18
20