0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

sklearnのDecisionTreeClassifierの結果の理解

Last updated at Posted at 2022-04-10

はじめに

今更感満載ですが、sklearn.tree.DecisionTreeClassifier(決定木分析)の結果を、Graphvizを使わずに、結果を読み解いてます。
「Graphvizインストールしないと見れない~」「インストールが~パスが~めんどくさい~~!」とか言うなよと。

偉そうなことを言いつつ、公式HPのUnderstanding the decision tree structureの焼き直しなので、正確な情報が好きな方はそちらへどうぞ。

決定木実行

とりあえずデータを読んで決定木分析するまで。

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)  # (1)
clf.fit(X_train, y_train)   # (2)

(1)で作成されたオブジェクトが、(2)の実施によって更新され、結果は clf の内部変数に書かれています。本記事のタイトルの「結果の理解」というのは、この内部変数の理解のことです。

irisデータについて

結果の理解の前に、今回の入力データであるirisデータについて。sklearn.datasets.load_iris

項目 説明
Classes 3 結果(iris)の種類数
Samples par class 50 1種類ごとの行数
Samples total 150 全部の行数(3x50=150)
Dimensionality 4 説明変数の種類数
Features real, positive 説明変数の型(実数型でプラス値しかない)

実際のデータはこんな感じ。

print(iris.data[:5])
# [[5.1 3.5 1.4 0.2]
# [4.9 3.  1.4 0.2]
# [4.7 3.2 1.3 0.2]
# [4.6 3.1 1.5 0.2]
# [5.  3.6 1.4 0.2]]

print(iris.target[:5])
# [0 0 0 0 0]

irisデータはtarget(結果)の0,1,2でソートされてるので、先頭の5行をとると全部ゼロですが、0,1,2が入ってます。
ついでに中身の理解のために書いておくと、意味は下記です。

  • 説明変数
    • sepal length in cm(がくの長さ)
    • sepal width in cm(がくの幅)
    • petal length in cm(花びらの長さ)
    • petal width in cm(花びらの幅)
  • 分類(こういう種類のIris・アヤメ)
    • 0:Iris-Setosa
    • 1:Iris-Versicolour
    • 2:Iris-Virginica

今回の話には何にも関係ありませんが、異常値に気づいたりするためにはこういう理解は必要です。

結果の理解

本題の clf.fit(X_train, y_train) 実行後の clf の読み方。

DecisionTreeClassifierクラス(clfオブジェクト)のプロパティ

clfの中身を見ていきます。sklearn.tree.DecisionTreeClassifier

内容は大きく2つに分類できて、1つは実行条件、もう1つは結果です。clf のプロパティを見ていくのですが、結果の変数名は末尾に_(アンダースコア)がついていて、実行条件はついていません。例えば、clf.max_depthは、実行条件の最大深さ。clf.n_features_in_は、入力ファイルを読んだ結果の説明変数の種類数です。
このclfから得られる情報は、全体的な情報と結果です。

for k,v in clf_model.__dict__.items():
    print(f'{k:20}: {v}')
# criterion           : gini
# splitter            : best
# max_depth           : 3
# min_samples_split   : 2
# min_samples_leaf    : 1
# min_weight_fraction_leaf: 0.0
# max_features        : None
# max_leaf_nodes      : None
# random_state        : None
# min_impurity_decrease: 0.0
# class_weight        : None
# ccp_alpha           : 0.0
# n_features_in_      : 4
# n_outputs_          : 1
# classes_            : [0 1 2]
# n_classes_          : 3
# max_features_       : 4
# tree_               : <sklearn.tree._tree.Tree object at 0x0000025340BF7AB0>

次は、ノード1つずつの結果を見たいので、tree_を見ていきます。

sklearn.tree._tree.Treeクラス(clf.tree_オブジェクト)のプロパティ

このクラスは、tree_.__dict__が使えないので、Understanding the decision tree structure を読んで、プロパティを直打ちしてみます。

print('--- tree info ---')
print(tree_obj)
print(f'{"node_count":20}: {tree_obj.node_count}')
print(f'{"max_depth":20}: {tree_obj.max_depth}')
print(f'{"children_left":20}: {tree_obj.children_left}')
print(f'{"children_right":20}: {tree_obj.children_right}')
print(f'{"feature":20}: {tree_obj.feature}')
print(f'{"threshold":20}: {tree_obj.threshold}')
print(f'{"n_node_samples":20}: {tree_obj.n_node_samples}')
print(f'{"value":20}: {tree_obj.value}')  # Understaindig ~ に記述なし
print(f'{"impurity":20}: {tree_obj.impurity}')
# <sklearn.tree._tree.Tree object at 0x0000025340BF7AB0>
# node_count          : 9
# max_depth           : 3
# children_left       : [ 1 -1  3  4 -1 -1  7 -1 -1]
# children_right      : [ 2 -1  6  5 -1 -1  8 -1 -1]
# feature             : [ 3 -2  3  2 -2 -2  2 -2 -2]
# threshold           : [ 0.80000001 -2. 1.75 4.95000005 -2. -2.  4.85000014 -2. -2.]
# n_node_samples      : [112  37  75  39  35   4  36   3  33]
# value               : [[[37. 37. 38.]] [[37.  0.  0.]] [[ 0. 37. 38.]] [[ 0. 36.  3.]] [[ 0. 34.  1.]] [[ 0.  2.  2.]] [[ 0.  1. 35.]] [[ 0.  1.  2.]] [[ 0.  0. 33.]]]
# impurity            : [0.66661352 0. 0.49991111 0.14201183 0.0555102 0.5 0.05401235 0.44444444 0. ]

ノードは9個で、最大深さが3,それ以外は謎の配列。ということで、ここからようやく本題です。
ノードの数が9個で、配列の要素はすべて9要素、ということから想像できると思いますが、ノードに対応した情報が配列になっています。つまりノード i の情報は、~~[i] で得られるという形。

ではそのプロパティを1つずつ見ます。

まず、決定木の一番上のノードは、i=0と決まっています。これは決め。以降はi=0の数値を見ていきます。

tree_.children_left[0]=1は、i=0の左の子はi=1の要素で、tree_.children_right[0]=2は、i=0の右の子はi=2の要素であることを指しています。つまり、左の子を掘り下げたいなら、i=1で同じことをすればよいということ。

i=0に話を戻して、feature[0]=3threshold[0]=0.80000001は、説明変数の列index=3(0開始)が、0.80000001以下の場合は、左の子、それ以外は右の子という意味です。features[0]=30.80000001以下ということは「4番目の変数であるpetal widthが0.8cm以下」という意味。

n_node_samples[0]=112は、このノードに来るサンプル数が112個ということです。irisの元データは150個あったのにi=0で112個というのは、train_test_spritで学習データがデフォルト値の75%、150*0.75=112.5→112個になっています。

value[0]=[[37. 37. 38.]]は、サンプル112個が、3クラスのそれぞれいくつかを示します。これはUnderstanding the decision tree structureに書いてなかったので、"~\Lib\site-packages\sklearn\tree_tree.pxd"のcdef class Treeから探りました。

impurity[0]=0.66661352は、不純度。0.0~1.0の数値で、この数字が1.0に近いとまだそのノードではばらけてる、0.0に近いと1種類に絞れているとみなせます。いくつだったらいいのかというのはケースバイケースです。(例えば、病気の判定だったら厳しく必要だが、ECサイトのレコメンドだったら緩くてよい)
計算方法は、3クラスのそれぞれの割合を2乗したものを足した数値を、1から引くという計算。value[0]=[[37. 37. 38.]]を使って計算してみると

1 - ( (37/112)**2 + (37/112)**2 + (38/112)**2 )
= 1 - ( 0.33...**2 + 0.30...**2 + 0.37...**2 )
= 1 - 0.333386...
= 0.66661352...

となり、手計算でもimpurity=0.66661352が出せました。

Treeのビューアー

「結果の理解」という意味ではここまでで終了です。以降は、「結果を理解」したうえで、それを可視化する話。

何度も書きますがUnderstanding the decision tree structureのほぼ丸コピーです。コメントを日本語にしたことと、valueを表示している点のみが相違点です。

def print_tree_structure(tree_obj):
    # Understanding the decision tree structure
    # https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html#sphx-glr-auto-examples-tree-plot-unveil-tree-structure-py
    print('--- print_tree_structure ---')
    n_nodes = tree_obj.node_count
    children_left = tree_obj.children_left
    children_right = tree_obj.children_right
    feature = tree_obj.feature
    threshold = tree_obj.threshold
    tree_values = tree_obj.value

    # node_depth, is_leaves, stackを調べる
    node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)
    stack = [(0, 0)]  # 最初に、ノードID=0、深さ=0を、1要素だけ登録しておく
    while len(stack) > 0:
        # pop()を使うことで、1つのノードは1回しか取り出さないことを保証している
        node_id, depth = stack.pop()
        node_depth[node_id] = depth

        # 左右の子が違っていたら、split_nodeと判断する
        is_split_node = children_left[node_id] != children_right[node_id]
        # split_node の場合は、stackに追加することで、次のループで使用される
        if is_split_node:
            stack.append((children_left[node_id], depth + 1))
            stack.append((children_right[node_id], depth + 1))
        else:
            is_leaves[node_id] = True

    # print開始
    print(
        "The binary tree structure has {n} nodes and has "
        "the following tree structure:\n".format(n=n_nodes)
    )
    for i in range(n_nodes):
        if is_leaves[i]:
            print(
                "{space}node={node} is a leaf node. {value}".format(
                    space=node_depth[i] * "\t",
                    node=i,
                    value=tree_values[i][0],
                )
            )
        else:
            print(
                "{space}node={node} is a split node: "
                "go to node {left} if X[:, {feature}] <= {threshold} "
                "else to node {right}. {value}".format(
                    space=node_depth[i] * "\t",
                    node=i,
                    left=children_left[i],
                    feature=feature[i],
                    threshold=threshold[i],
                    right=children_right[i],
                    value=tree_values[i][0],
                )
            )
# --- print_tree_structure ---
# The binary tree structure has 9 nodes and has the following tree structure:
# 
# node=0 is a split node: go to node 1 if X[:, 3] <= 0.800000011920929 else to node 2. [37. 37. 38.]
#         node=1 is a leaf node. [37.  0.  0.]
#         node=2 is a split node: go to node 3 if X[:, 3] <= 1.75 else to node 6. [ 0. 37. 38.]
#                 node=3 is a split node: go to node 4 if X[:, 2] <= 4.950000047683716 else to node 5. [ 0. 36.  3.]
#                         node=4 is a leaf node. [ 0. 34.  1.]
#                         node=5 is a leaf node. [0. 2. 2.]
#                 node=6 is a split node: go to node 7 if X[:, 0] <= 5.950000047683716 else to node 8. [ 0.  1. 35.]
#                         node=7 is a leaf node. [0. 1. 2.]
#                         node=8 is a leaf node. [ 0.  0. 33.]

これで、Graphvizとかを使用せずとも、同じ情報が見られます。

おわりに

Graphvizとかdtreevizが楽なのはわかるけど、exeをインストールしないといけない気持ち悪さとか、それを使わないと見ることができないっていう技術力の低さなんなのという憤り(?)から、ちょっと調べたらすぐできるよ!という内容の記事でした。

個人的には、この情報を使って、CSVをポイっと入れると決定木分析ができる可視化ツールを作る予定です。pythonを書ける人は(Graphvizを使わず)自分でやれ、書けないひとは誰かに頼れ、というすみ分けをしちゃいたいという気持ち。

あとついでにひとこと。sklearnでもなんでもいいから、CHAIDの決定木を実装してほしいなぁー!

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?