21
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

決定木(分類)のハイパーパラメータとチューニング

Last updated at Posted at 2019-06-03

はじめに

 乳癌の腫瘍が良性であるか悪性であるかを判定するためのウィスコンシン州の乳癌データセットについて、決定木とハイパーパラメータのチューニングにより分類器を作成する。データはsklearnに含まれるもので、データ数は569、そのうち良性は212、悪性は357、特徴量は30種類ある。

シリーズ

決定木とは

機械学習の分野においては決定木は予測モデルであり、ある事項に対する観察結果から、その事項の目標値に関する結論を導く。内部節点は変数に対応し、子節点への枝はその変数の取り得る値を示す。 葉(端点)は、根(root)からの経路によって表される変数値に対して、目的変数の予測値を表す。
(wikipediaより)

決定木のハイパーパラメータ

詳細は以下を参照されたい。
DicisionTrees

ハイパーパラメータ 選択肢 default
criterion gini、entropy gini
splitter best、random best
max_features int、float、string型 or None -
max_depth int型、None None
min_samples_split int型 2
min_samples_leaf int型 1
min_weight_fraction_leaf float型 0
max_leaf_nodes int型、None None
class_weight 辞書型、None None
random_state int型 None
presort bool型 False

手順

  • 乳癌データの読み込み
  • トレーニングデータ、テストデータの分離
  • 条件設定
  • 決定木の実行(グリッドサーチ)
  • ハイパーパラメータをチューニングしない場合との比較

pythonによる実装

%%time
from tqdm import tqdm
import scipy.stats
import matplotlib.pyplot as plt 
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

#乳癌データの読み込み
cancer_data = load_breast_cancer()

#トレーニングデータ、テストデータの分離
train_X, test_X, train_y, test_y = train_test_split(cancer_data.data, cancer_data.target, random_state=0)

#条件設定
max_score = 0
SearchMethod = 0
DTC_grid = {DecisionTreeClassifier(): {"criterion": ["gini", "entropy"],
                                       "splitter": ["best", "random"],
                                       "max_depth": [i for i in range(1, 11)],
                                       "min_samples_split": [i for i in range(2, 11)],
                                       "min_samples_leaf": [i for i in range(1, 11)],
                                       "random_state": [i for i in range(0, 101)]
                                      }}

#決定木の実行
for model, param in tqdm(DTC_grid.items()):
    clf = GridSearchCV(model, param)
    clf.fit(train_X, train_y)
    pred_y = clf.predict(test_X)
    score = f1_score(test_y, pred_y, average="micro")

    if max_score < score:
        max_score = score
        best_param = clf.best_params_
        best_model = model.__class__.__name__
        
print("ベストスコア:{}".format(max_score))
print("モデル:{}".format(best_model))
print("パラメーター:{}".format(best_param))

#ハイパーパラメータを調整しない場合との比較
model = DecisionTreeClassifier()
model.fit(train_X, train_y)
score = model.score(test_X, test_y)
print("")
print("デフォルトスコア:", score)

結果

100%|██████████████████████████████████████████| 1/1 [41:23<00:00, 2483.41s/it]
ベストスコア:0.951048951048951
モデル:DecisionTreeClassifier
パラメーター:{'criterion': 'entropy', 'max_depth': 7, 'min_samples_leaf': 1, 'min_samples_split': 4, 'random_state': 41, 'splitter': 'random'}

デフォルトスコア: 0.867132867133
Wall time: 41min 23s

おわりに

 ハイパーパラメータのチューニングにより、デフォルトよりも高い正解率を得ることができた。

21
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?