LoginSignup
1
1

More than 3 years have passed since last update.

チュートリアルのとおりOptuna使ってみたけどわからなかったので少し調べてみた

Last updated at Posted at 2020-08-18

Optunaとは

PFNが開発したハイパーパラメータ自動最適化フレームワーク

機械学習モデルのハイパーパラメーターの最適化のために作られたベイズ最適化Package

何ができるのか

ベイズ最適化のTPEを用いた最適化を行う

シングルプロセスで手軽に使うことができる

多数のマシンで並列に学習することもできる

どう使うのか

import optuna
from sklearn.metrics import log_loss # 損失関数を計算するライブラリ

# Optunaで最適化するための関数を定義
# Optunaは、return(ここではscore)で返す値が最小になるようなハイパーパラメータを探索してくれる
def objective(trial):
    # ハイパーパラメータの探索範囲を設定
    # ↓はLightGBMのパラメータ
    # 参考:https://qiita.com/nabenabe0928/items/6b9772131ba89da00354
    #* objective:目的関数
    #  * regression:回帰を解く
    #  * binary:二値分類
    #  * multiclass:他クラス分類
    #* learning_rate:学習率。デフォルトは0.1。大きなnum_iterationsを取るときは小さなlearning_rateを取ると精度が上がる
    #  * num_leaves:木にある分岐の個数。デフォルトは31。大きくすると精度は上がるが過学習が進む。
    #  * max_bin:一つの分岐に入るデータ数の最大値。小さくすると細かく学習ができるので精度を上げられる。大きくするとざっくりとした学習となり一般性が上がる。
    params = {
        'objective': 'binary', # objectiveはbinaryで固定
        'max_bin': trial.suggest_int('max_bin', 255, 500), # max_binは255~500の間の整数値を探索
        'learning_rate': 0.05, # learning_rateは0.05で固定
        'num_leaves': trial.suggest_int('num_leaves', 32, 128), # num_leavesは32~128の間の整数値を探索
    }

    # LightGBMを用いて学習
    lgb_train = lgb.Dataset(X_train, y_train, categorical_feature=categorical_features)
    lgb_eval = lgb.Dataset(X_valid, y_valid, reference=lgb_train, categorical_feature=categorical_features)

    model = lgb.train(
        params, lgb_train,
        valid_sets=[lgb_train, lgb_eval],
        verbose_eval=10,
        num_boost_round=1000,
        early_stopping_rounds=10
    )

    # 検証用データセットに対する性能を調べている
    y_pred_valid = model.predict(X_valid, num_iteration=model.best_iteration)
    score = log_loss(y_valid, y_pred_valid)
    return score
# Optuna最適化のためのセッションを作成。seed=0とすることで乱数を固定。
study = optuna.create_study(sampler=optuna.samplers.RandomSampler(seed=0))
# Optunaの計算を実行。第一引数に最小化したい関数を渡す。n_trialsは実行回数で今回は小さめの40回を指定している
study.optimize(objective, n_trials=40)
study.best_params
params = {
    'objective': 'binary',
    'max_bin': study.best_params['max_bin'],
    'learning_rate': 0.05,
    'num_leaves': study.best_params['num_leaves']
}

# LightGBMを用いて学習
lgb_train = lgb.Dataset(X_train, y_train, categorical_feature=categorical_features)
lgb_eval = lgb.Dataset(X_valid, y_valid, reference=lgb_train, categorical_feature=categorical_features)

model = lgb.train(
    params, lgb_train,
    valid_sets=[lgb_train, lgb_eval],
    verbose_eval=10,
    num_boost_round=1000,
    early_stopping_rounds=10
)

y_pred = model.predict(X_test, num_iteration=model.best_iteration)
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1