LoginSignup
1
2

More than 3 years have passed since last update.

決定木(分類)のハイパーパラメータとチューニング2

Last updated at Posted at 2019-06-03

はじめに

 乳癌の腫瘍が良性であるか悪性であるかを判定するためのウィスコンシン州の乳癌データセットについて、決定木とハイパーパラメータのチューニングにより分類器を作成する。データはsklearnに含まれるもので、データ数は569、そのうち良性は212、悪性は357、特徴量は30種類ある。
 前回の記事で最適なハイパーパラメータを決定した。本記事ではmax_depthとAccuracyの関係と決定木の視覚化について記載する。

シリーズ

max_depthとAccuracyの関係

 max_depth以外のパラメータを前回最適化した値とし、max_depthを1~10まで振ったときのAccuracyを以下の図に示す。max_depthとは、木の深さの最大値を示す。max_depthが3~5の時、Accuracyは最大となるが、max_depthが6以降はAccuracyが低下していく。これは、max_depthを大きくしすぎたことによる過学習が起きているためと思われる。なお、プログラム上、Accuracyが最大となるmax_depthは3と判定している。

tree_graph.png

決定木の可視化

 決定木をdotファイルに変換し、Graphvizによって可視化した図をいかに示す。ハイパーパラメータは最適化した値に設定している。max_depthは3のため、3段階に分岐していることがわかる。評価するサンプル数は426あり、徐々にふるいにかけていく。最初のパラメータがWorst radiusのため、これが最も重要なパラメータと思われる。以下、条件についてTrue or Falseでふるいにかけていき、対象の腫瘍がbenign(良性)であるかmalignant(悪性)であるかを判定している。
tree.png

手順

乳癌データの読み込み
トレーニングデータ、テストデータの分離
条件設定
決定木の実行
決定木をdotファイルに変換
グラフのプロット

%%time
import matplotlib.pyplot as plt 
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

#乳癌データの読み込み
cancer_data = load_breast_cancer()

#トレーニングデータ、テストデータの分離
train_X, test_X, train_y, test_y = train_test_split(cancer_data.data, cancer_data.target, random_state=0)

#条件設定
max_score = 0
accuracy = []
depth_list = [i for i in range(1, 11)]

#決定木の実行
for depth in tqdm(depth_list):
    clf = DecisionTreeClassifier(criterion="entropy", splitter="random", max_depth=depth, min_samples_split=4, min_samples_leaf=1, random_state=41)
    clf.fit(train_X, train_y)
    accuracy.append(clf.score(test_X, test_y))
    if max_score < clf.score(test_X, test_y):
        max_score = clf.score(test_X, test_y)
        depth_ = depth

#決定木をdotファイルに変換
clf = DecisionTreeClassifier(criterion="entropy", splitter="random", max_depth=depth_, min_samples_split=4, min_samples_leaf=1, random_state=41)
clf.fit(train_X, train_y)
tree.export_graphviz(clf, out_file="tree.dot", feature_names=cancer_data.feature_names, class_names=cancer_data.target_names, filled=True)

#グラフのプロット
plt.plot(depth_list, accuracy)
plt.title("Accuracy change")
plt.xlabel("max_depth")
plt.ylabel("Accuracy")
plt.show()

print("max_depth:{}".format(depth_))
print("ベストスコア:{}".format(max_score))

出力

max_depth:3
ベストスコア:0.972027972027972
Wall time: 112 ms

おわりに

 max_depthがAccuracyに大きく影響することがわかった。また、決定木を可視化することで難解なAIの思考についての説明をなんとなくできた気がする。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2