決定木で分類できるのはいいんだけど、どういう基準で分類していることが多いのか整理したい。そこで、決定木による分類基準を概観する方法を検討しました。
参考にさせていただいたのは scikit-learnの決定木系モデルを視覚化する方法
決定木の詳細を見るのは Graphviz (Graph Visualization Software) で視覚化するといいらしいですが、そこに出力された木を一個一個眺めるのってしんどいじゃないですか。なのでその結果を集計して概観したいなと。
iris のデータをインポート
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import re
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_validation import train_test_split
from sklearn import tree
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
よく知られている、こういう感じのデータを使います。
df
sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | species | |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
1 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
2 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
3 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
4 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |
5 | 5.4 | 3.9 | 1.7 | 0.4 | setosa |
6 | 4.6 | 3.4 | 1.4 | 0.3 | setosa |
7 | 5.0 | 3.4 | 1.5 | 0.2 | setosa |
8 | 4.4 | 2.9 | 1.4 | 0.2 | setosa |
9 | 4.9 | 3.1 | 1.5 | 0.1 | setosa |
10 | 5.4 | 3.7 | 1.5 | 0.2 | setosa |
11 | 4.8 | 3.4 | 1.6 | 0.2 | setosa |
12 | 4.8 | 3.0 | 1.4 | 0.1 | setosa |
13 | 4.3 | 3.0 | 1.1 | 0.1 | setosa |
14 | 5.8 | 4.0 | 1.2 | 0.2 | setosa |
15 | 5.7 | 4.4 | 1.5 | 0.4 | setosa |
16 | 5.4 | 3.9 | 1.3 | 0.4 | setosa |
17 | 5.1 | 3.5 | 1.4 | 0.3 | setosa |
18 | 5.7 | 3.8 | 1.7 | 0.3 | setosa |
19 | 5.1 | 3.8 | 1.5 | 0.3 | setosa |
20 | 5.4 | 3.4 | 1.7 | 0.2 | setosa |
21 | 5.1 | 3.7 | 1.5 | 0.4 | setosa |
22 | 4.6 | 3.6 | 1.0 | 0.2 | setosa |
23 | 5.1 | 3.3 | 1.7 | 0.5 | setosa |
24 | 4.8 | 3.4 | 1.9 | 0.2 | setosa |
25 | 5.0 | 3.0 | 1.6 | 0.2 | setosa |
26 | 5.0 | 3.4 | 1.6 | 0.4 | setosa |
27 | 5.2 | 3.5 | 1.5 | 0.2 | setosa |
28 | 5.2 | 3.4 | 1.4 | 0.2 | setosa |
29 | 4.7 | 3.2 | 1.6 | 0.2 | setosa |
... | ... | ... | ... | ... | ... |
150 rows × 5 columns
scatter plot による概観
まずは scatter plot による概観をしてみましょうか。異なる species は異なる色で彩色しました。
cmap = plt.get_cmap('coolwarm')
colors = [cmap((c)/ 3) for c in [list(iris.target_names).index(name) for name in df['species'].tolist()]]
pd.plotting.scatter_matrix(df.dropna(axis=1)[df.columns[:]], figsize=(8, 8), color=colors)
plt.show()
決定木による分類と、分類基準の集計
さて、ここからが本番です。
- 決定木による分類を10回繰り返しました。
- 1回の計算あたりの木の本数は100本にしました。
- 特徴の重要性(feature importance)を計算しました。
- 分類基準の分布(criteria_distribution)を集計しました。
ここで、分類基準の取得方法を説明すると
- tree.export_graphviz(estimator) というメソッドを使うと "tree.dot" という名のテキストファイルが出力され、決定木の構造が記述される。
- その "tree.dot" のファイルを読み込んで、分類基準の部分だけ抜き出して criteria_distribution に集計する。
- サンプルサイズの大きなノードを分割したような分類基準は大きめにカウントする(ここでは、サンプルサイズの分だけカウントするようにする)。
という流れです。
features = df.columns[:4]
label = df['species']
criteria_distribution = {}
accumulated_feature_importances = np.zeros(len(df.columns) - 1)
for times_of_learning in range(10):
df_train, df_test, label_train, label_test = train_test_split(df[features], label)
clf = RandomForestClassifier(n_estimators=100)
clf.fit(df_train, label_train)
accumulated_feature_importances += clf.feature_importances_
for estimator in clf.estimators_:
tree.export_graphviz(estimator)
for line in open("tree.dot"):
matched = re.search('X\[.+?\n', line)
if matched:
equation = matched.group(0).split("\\n")[0]
numerals = re.findall('[0-9\.]+', line)
if numerals:
#print(equation, numerals)
if int(numerals[1]) not in criteria_distribution.keys():
criteria_distribution[int(numerals[1])] = []
for num_of_samples in range(int(numerals[-1])):
criteria_distribution[int(numerals[1])].append(float(numerals[2]))
特徴の重要性の大きい順に、分類基準の分布とデータの分布を概観する
criteria が、分類基準の分布です。重要性の大きい特徴が、3つのクラスを比較的上手に分類できていることが見て取れます。
ranking = np.argsort(accumulated_feature_importances)[::-1]
for rank, feature in enumerate(ranking):
print("feature #" + str(rank + 1) + " = " + df.columns[feature])
fig, axes = plt.subplots(nrows = len(iris.target_names) + 1, ncols=1, figsize=(8,8))
print("Accumulated Feature importance = " + str(accumulated_feature_importances[feature]))
x_max = max(list(criteria_distribution[feature]) + list(df[df.columns[feature]]))
x_min = min(list(criteria_distribution[feature]) + list(df[df.columns[feature]]))
axes[0].hist(criteria_distribution[feature], bins=20, label="criteria")
axes[0].legend()
axes[0].set_xlim([x_min, x_max])
for target, target_name in enumerate(iris.target_names):
axes[target + 1] = plt.subplot(len(iris.target_names) + 1, 1, target + 2)
axes[target + 1].hist(list(df[df[df.columns[-1]] == target_name][df.columns[feature]]), bins=10,
label=target_name, alpha=0.5)
axes[target + 1].legend()
axes[target + 1].set_xlim([x_min, x_max])
plt.xlabel(df.columns[feature])
plt.show()
feature #1 = petal width (cm)
Accumulated Feature importance = 4.45000008496
feature #2 = petal length (cm)
Accumulated Feature importance = 4.30410137014
feature #3 = sepal length (cm)
Accumulated Feature importance = 0.976500917131
feature #4 = sepal width (cm)
Accumulated Feature importance = 0.269397627763