5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

決定木による分類基準を集計する

Last updated at Posted at 2018-08-29

決定木で分類できるのはいいんだけど、どういう基準で分類していることが多いのか整理したい。そこで、決定木による分類基準を概観する方法を検討しました。

参考にさせていただいたのは scikit-learnの決定木系モデルを視覚化する方法

決定木の詳細を見るのは Graphviz (Graph Visualization Software) で視覚化するといいらしいですが、そこに出力された木を一個一個眺めるのってしんどいじゃないですか。なのでその結果を集計して概観したいなと。

iris のデータをインポート

%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import re

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_validation import train_test_split
from sklearn import tree

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)

よく知られている、こういう感じのデータを使います。

df
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
5 5.4 3.9 1.7 0.4 setosa
6 4.6 3.4 1.4 0.3 setosa
7 5.0 3.4 1.5 0.2 setosa
8 4.4 2.9 1.4 0.2 setosa
9 4.9 3.1 1.5 0.1 setosa
10 5.4 3.7 1.5 0.2 setosa
11 4.8 3.4 1.6 0.2 setosa
12 4.8 3.0 1.4 0.1 setosa
13 4.3 3.0 1.1 0.1 setosa
14 5.8 4.0 1.2 0.2 setosa
15 5.7 4.4 1.5 0.4 setosa
16 5.4 3.9 1.3 0.4 setosa
17 5.1 3.5 1.4 0.3 setosa
18 5.7 3.8 1.7 0.3 setosa
19 5.1 3.8 1.5 0.3 setosa
20 5.4 3.4 1.7 0.2 setosa
21 5.1 3.7 1.5 0.4 setosa
22 4.6 3.6 1.0 0.2 setosa
23 5.1 3.3 1.7 0.5 setosa
24 4.8 3.4 1.9 0.2 setosa
25 5.0 3.0 1.6 0.2 setosa
26 5.0 3.4 1.6 0.4 setosa
27 5.2 3.5 1.5 0.2 setosa
28 5.2 3.4 1.4 0.2 setosa
29 4.7 3.2 1.6 0.2 setosa
... ... ... ... ... ...

150 rows × 5 columns

scatter plot による概観

まずは scatter plot による概観をしてみましょうか。異なる species は異なる色で彩色しました。

cmap = plt.get_cmap('coolwarm')
colors = [cmap((c)/ 3) for c in [list(iris.target_names).index(name) for name in df['species'].tolist()]]
pd.plotting.scatter_matrix(df.dropna(axis=1)[df.columns[:]], figsize=(8, 8), color=colors) 
plt.show()

output_2_0.png

決定木による分類と、分類基準の集計

さて、ここからが本番です。

  • 決定木による分類を10回繰り返しました。
  • 1回の計算あたりの木の本数は100本にしました。
  • 特徴の重要性(feature importance)を計算しました。
  • 分類基準の分布(criteria_distribution)を集計しました。

ここで、分類基準の取得方法を説明すると

  • tree.export_graphviz(estimator) というメソッドを使うと "tree.dot" という名のテキストファイルが出力され、決定木の構造が記述される。
  • その "tree.dot" のファイルを読み込んで、分類基準の部分だけ抜き出して criteria_distribution に集計する。
  • サンプルサイズの大きなノードを分割したような分類基準は大きめにカウントする(ここでは、サンプルサイズの分だけカウントするようにする)。

という流れです。

features = df.columns[:4]
label = df['species']

criteria_distribution = {}
accumulated_feature_importances = np.zeros(len(df.columns) - 1)

for times_of_learning in range(10):
    df_train, df_test, label_train, label_test = train_test_split(df[features], label)
    clf = RandomForestClassifier(n_estimators=100)
    clf.fit(df_train, label_train)
    accumulated_feature_importances += clf.feature_importances_

    for estimator in clf.estimators_:
        tree.export_graphviz(estimator)
        for line in open("tree.dot"):
            matched = re.search('X\[.+?\n', line)
            if matched:
                equation = matched.group(0).split("\\n")[0]
                numerals = re.findall('[0-9\.]+', line)
                if numerals:
                    #print(equation, numerals)
                    if int(numerals[1]) not in criteria_distribution.keys():
                        criteria_distribution[int(numerals[1])] = []
                    for num_of_samples in range(int(numerals[-1])):
                        criteria_distribution[int(numerals[1])].append(float(numerals[2]))

特徴の重要性の大きい順に、分類基準の分布とデータの分布を概観する

criteria が、分類基準の分布です。重要性の大きい特徴が、3つのクラスを比較的上手に分類できていることが見て取れます。

ranking = np.argsort(accumulated_feature_importances)[::-1]
for rank, feature in enumerate(ranking):
    print("feature #" + str(rank + 1) + " = " + df.columns[feature])
    fig, axes = plt.subplots(nrows = len(iris.target_names) + 1, ncols=1, figsize=(8,8))
    print("Accumulated Feature importance = " + str(accumulated_feature_importances[feature]))
    x_max = max(list(criteria_distribution[feature]) + list(df[df.columns[feature]]))
    x_min = min(list(criteria_distribution[feature]) + list(df[df.columns[feature]]))
    axes[0].hist(criteria_distribution[feature], bins=20, label="criteria")
    axes[0].legend()
    axes[0].set_xlim([x_min, x_max])
    for target, target_name in enumerate(iris.target_names):
        axes[target + 1] = plt.subplot(len(iris.target_names) + 1, 1, target + 2)
        axes[target + 1].hist(list(df[df[df.columns[-1]] == target_name][df.columns[feature]]), bins=10,
                             label=target_name, alpha=0.5)
        axes[target + 1].legend()
        axes[target + 1].set_xlim([x_min, x_max])
    plt.xlabel(df.columns[feature])
    plt.show()

feature #1 = petal width (cm)
Accumulated Feature importance = 4.45000008496

output_4_1.png

feature #2 = petal length (cm)
Accumulated Feature importance = 4.30410137014

output_4_3.png

feature #3 = sepal length (cm)
Accumulated Feature importance = 0.976500917131

output_4_5.png

feature #4 = sepal width (cm)
Accumulated Feature importance = 0.269397627763

output_4_7.png


5
4
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?