LoginSignup
8
8

More than 5 years have passed since last update.

XGBoostでOptunaを使う

Posted at

概要

Optunaを使ってXGBoostのハイパーパラメータ探索をしてみる。

インストール

pip install optuna

コード概要

import xgboost as xgb
import sklearn.metrics import accuracy_score
import optuna
import functools

# ポイント1
def opt(X_train, y_train, X_test, y_test, trial):
    """ optunaでのハイパーパラメータ探索用関数 """
    n_estimators = trial.suggest_int('n_estimators', 0, 1000)
    max_depth = trial.suggest_int('max_depth', 1, 20)
    min_child_weight = trial.suggest_int('min_child_weight', 1, 20)
    learning_rate = trial.suggest_discrete_uniform('learning_rate', 0.001, 0.01, 0.001)
    scale_pos_weight = trial.suggest_int('scale_pos_weight', 1, 100)
    subsample = trial.suggest_discrete_uniform('subsample', 0.5, 0.9, 0.1)
    colsample_bytree = trial.suggest_discrete_uniform('colsample_bytree', 0.5, 0.9, 0.1)

    xgboost_tuna = xgb.XGBClassifier(
        random_state=42, # ここは固定の数値
        n_estimators = n_estimators,
        max_depth = max_depth,
        min_child_weight = min_child_weight,
        learning_rate = learning_rate,
        scale_pos_weight = scale_pos_weight,
        subsample = subsample,
        colsample_bytree = colsample_bytree,
    )
    xgboost_tuna.fit(X_train, y_train)
    tuna_pred_test = xgboost_tuna.predict(X_test)
    # ポイント2
    return (1.0 - (accuracy_score(y_test, tuna_pred_test)))

def main():
    # data preparing
    X_train, y_train = foo()
    X_test, y_test = bar()

    clf = xgb.XGBClassifier()
    # ポイント3
    study = optuna.create_study()
    study.optimize(functools.partial(opt, feature_train, result_train, teature_test, result_test), n_trials=100)

    # ポイント4
    # study.optimizeで得られたパラメータをxgb_paramのdictに格納して、学習開始。
    clf = xgb.XGBClassifier(**study.best_params)
    clf.fit(feature_train, result_train)

ポイントメモ

ポイント1

optinaで利用する関数の箇所で、作るMLのモデルのハイパーパラメータの型や探索レンジを設定する。
suggest_intで整数型のハイパーパラメータを、suggest_discrete_uniformでfloat型のハイパーパラメータを指定。

他にどんなものを指定できるかは、ここ参照。

ポイント2

最小にしたいものをreturnに設定。
今回は正解と予測の差を最小にしたいため、accuracyを1から引いて表現。
もし、weight付きで最小にしたいなどのやりたいことがあれば、ここの式をいじってあげればそれに沿った探索が走る。

ポイント3

ポイント1で作成した関数を用いて探索。n_trialsで探索回数を指定。

ポイント4

探索した結果もっとも良いハイパーパラメータは、study.best_ best_paramsに格納されるため、これをxgboostのハイパーパラメータとしてセットして、fitをかければOK。

8
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
8