LoginSignup
1
5

More than 3 years have passed since last update.

決定木のハイパーパラメータをOptunaで最適化する

Last updated at Posted at 2019-09-14

機械学習をする際に自身で決めなければならないパラメータを、ハイパーパラメータと言います。

Optunaは機械学習のために設計されたハイパーパラメータ最適化フレームワークです。

今回はsickit-learnのDecisionTreeClassifierのハイパーパラメータチューニングを試みます。

以下の公式ドキュメントに各ハイパーパラメータの詳細があります。

irisデータセットでrandome_state以外何も指定せずに実行してみます。

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
X = iris.data
y = iris.target
(X_train, X_test,
 y_train, y_test) = train_test_split(
    X, y, test_size=0.3, random_state=0,
)
clf.fit(X_train, y_train)
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
                       max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=0, splitter='best')
print(clf.score(X_train, y_train))
print(clf.score(X_test, y_test))
1.0
0.9777777777777777

ここからがOptunaを用いたハイパーパラメータ最適化の部分です。

from sklearn.metrics import accuracy_score
def objective(trial):
    max_depth = trial.suggest_int("max_depth", 2, 612)
    min_samples_split = trial.suggest_int("min_samples_split", 2, 612)
    max_leaf_nodes = int(trial.suggest_int("max_leaf_nodes", 2, 612))
    criterion = trial.suggest_categorical("criterion", ["gini", "entropy"])

    DTC = DecisionTreeClassifier(min_samples_split = min_samples_split, 
                                max_leaf_nodes = max_leaf_nodes,
                                criterion = criterion)
    DTC.fit(iris.data, iris.target)
    return 1.0 - accuracy_score(y_test, DTC.predict(X_test))
!pip install optuna
import optuna
study = optuna.create_study()
study.optimize(objective, n_trials = 500)
print(study.best_params)
print(1.0 - study.best_value)
{'max_depth': 78, 'min_samples_split': 2, 'max_leaf_nodes': 539, 'criterion': 'gini'}
1.0

最適なパラメータは
{'max_depth': 78, 'min_samples_split': 2, 'max_leaf_nodes': 539, 'criterion': 'gini'}
であることがわかりました。

ソースコードはgithubでも公開しております。

1
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
5