LoginSignup
2
1

Optunaでパラメータを指定して実行する,studyを継ぎ足す

Posted at

ハイパーパラメータの自動最適化フレームワークである Optuna は目的関数を書くと自動で最適なハイパーパラメータを探索してくれます.
基本的にはアルゴリズムに任せるのが良いですが,ユーザー側でパラメータをいじるなど変なことをしようとすると少々面倒なのでまとめておきます.

使用環境は
Python 3.9.16
Optuna 3.1.0

例:LightGBM モデルのハイパーパラメータ調整(探索のターゲットや数値は適当です)

import optuna

import lightgbm as lgb
from sklearn.model_selection import train_test_split

# データ読み込み
train_X = pd.read_pickle("../data/train_X.pkl").values
train_y = pd.read_pickle("../data/train_y.pkl").values
tra_x, val_x, tra_y, val_y = train_test_split(train_X, train_y, test_size=0.2, random_state=42)

# 目的関数の設定
def objective(trial):
    param = {
        'objective': 'regression',
        'metric': 'rmse',
        'boosting_type': 'gbdt',
        'verbosity': -1,
        'seed': 42,
        'num_leaves': trial.suggest_int('num_leaves', 5, 500, log=True),
        'learning_rate': trial.suggest_float('learning_rate', 0.0001, 0.1, log=True)
    }

    lgb_train = lgb.Dataset(tra_x, label=tra_y)
    lgb_val = lgb.Dataset(val_x, label=val_y)

    model = lgb.train(param, lgb_train, num_boost_round=1000, valid_sets=[lgb_train, lgb_val],
                      early_stopping_rounds=50, verbose_eval=False)

    val_pred = model.predict(val_x)
    rmse = mean_squared_error(val_y, val_pred, squared=False)

    return rmse
# 最適化の実行
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=100)
# => 100回試行が行われる

objective 関数内の trial.suggest_~~ で指定したパラメータの範囲・分布の中で,なるべく少ない試行回数で最も良い関数値となるパラメータを探索します.
Optuna は基本的に TPE というアルゴリズムを使用して1,これまでの試行結果から次に探索すべき点を提案してループを回していきます.

指定したパラメータで試行を行う

Optuna は目的関数で指定されたパラメータ空間内でよしなにパラメータを探索していきますが,「予め大体どのあたりのパラメータが良いか見当がついている」というようなシチュエーションなら先にそこを調べてしまったほうが良いでしょう.
手動で探索するパラメータを指定するには study.enqueue_trial(params) を使います.dict でパラメータを指定すると探索のキューに入るのでその後 study.optimize を実行すると指定した点が優先的に実行されます.

study = optuna.create_study(direction='minimize')
study.enqueue_trial(
    {'num_leaves': 40, 'learning_rate': 0.01}, user_attrs={"memo": "init"})
study.enqueue_trial(
    {'num_leaves': 50, 'learning_rate': 0.001}, user_attrs={"memo": "init"})
study.optimize(objective, n_trials=3)
# =>
# 1. {'num_leaves': 40, 'learning_rate': 0.01} で試行
# 2. {'num_leaves': 50, 'learning_rate': 0.001} で試行
# 3. 以降は最適化アルゴリズムに従った次のパラメータで試行

なお,オプション引数として user_attrs を指定すると探索とは関係ない変数を設定することができるのでメモなどに使えます.

参考:パラメータの初期値を指定してOptunaで最適化してみる │ キヨシの命題

探索済みのデータを新たなstudyに引き継ぐ

一度適当に最適化を実行してみたけれどよく考えたら log 空間に指定すべきだった,というように,途中でパラメータ空間を変更して探索をやり直したいことがあるかもしれません2
せっかく試行を行ったのにデータを捨ててしまうのはもったいないので,これまでの試行結果も踏まえて予測をしてほしいところですが,単純に既存の study で study.optimize に新しい目的関数を与えるとパラメータ空間が違うということでエラーになってしまいます.

study = optuna.create_study(direction='minimize')
study.optimize(objective1, n_trials=100)
# => 結果を見てパラメータ空間を変更

study.optimize(objective2, n_trials=100)
# => "ValueError: Cannot set different log configuration to the same parameter name." が発生

trial の形が違うので distributions を指定して trial を作り直してやる必要があります.

study = optuna.create_study(direction='minimize')

# old_study から試行結果を引き継ぐ
for t in old_study.trials:
    study.add_trial(
        optuna.trial.create_trial(
            params=t.params,
            distributions={
                'num_leaves': optuna.distributions.IntDistribution(5, 500, log=True),
                'learning_rate': optuna.distributions.FloatDistribution(0.0001, 0.1, log=False)
            },
            value=t.value
        )
    )

study.optimize(objective, n_trials=100)
# => old_study の結果も踏まえた探索を実行

参考:optunaで探索済みのデータをtrialに渡して重要度を評価する方法 - Qiita

目的関数で書いた分布を distributions に直してわざわざ打ち込むのは面倒ですが,試行を実行しないと optimize 関数の中身を見てくれないようなので3,先に1回実行してしまうと記述が楽です.

study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=1)
# =>
# 自動で決まる最初のパラメータで試行

# old_study から試行結果を引き継ぐ
trial0 = study.trials[0]
for t in old_study.trials:
    trial = trial0
    trial.params = t.params
    trial.value = t.value
    study.add_trial(trial)

study.optimize(objective, n_trials=100)
# => old_study の結果も踏まえた探索を実行
  1. 機械学習向けハイパーパラメータ自動最適化フレームワーク Optunaの最新版となるv3を公開 - 株式会社Preferred Networks

  2. 実のところ,パラメータの上限下限の変更や新たなパラメータ追加などは単に新しい目的関数で optimize を実行すればそのまま動くので,そんなによくあるシチュエーションでもないです.Uniform ⇔ LogUniform の切り替えくらい.

  3. 途中で止めようにもどこで止めるのか判別できないので実現できなさそうですが,ちゃんと探せば静的に解析してくれる機能があるかもしれません.あったら教えていただけるとありがたいです.

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1