はじめに
過去の投稿12では、Kaggleのデータセット「Adult Census Income3」を使って二値分類問題にチャレンジしました。
その際、カテゴリ値の取扱い方法としては、Label EncodingとOne-hot Encodingを採用した場合を紹介しました。
が、ハイパーパラメータ・チューニングには取り組んでいませんでした。
そこで本稿では、Optunaを利用したLightGBMのハイパーパラメータ・チューニングにチャレンジします(問題設定は過去の投稿に同じ)。
本稿で紹介すること
- ハイパーパラメータ・チューニング
- 乱数Seedを指定した試行の自動化
過去の投稿からの変更部分を中心に、説明します。
が、紹介するPythonコードはリファクタリング未済のため、PEP84には則っていません!
本稿で紹介しないこと
- データの収集および整形
- 特徴量設計
- 予測モデルの作成
- 予測精度の確認
- 特徴量重要度の確認(可視化)
分析環境の各種情報
筆者は、Windows 11のホスト上でJupyterを起動しています。(コンテナは不使用)
- Python 3.9.13
- pip 23.0.1
- pandas 1.5.3
- numpy 1.24.2
- scikit-learn 1.2.1
- lightgbm 3.3.5
- optuna 3.1.0
- matplotlib 3.6.3
ハイパーパラメータ・チューニング
こちらのサイト5で公開されているPythonコードを参考にしました。
LightGBMTunerのインスタンス化、学習(≒探索)、最適モデル(およびパラメータ)取得などの流れは同じです。
ポイントは、optuna.integration.lightgbm.LightGBMTuner6を利用することです。
# LightGBMのハイパーパラメータ
params = {
# 勾配ブースティング木(GBDT:Gradient Boosting Decision Tree)
'boosting_type': 'gbdt',
# 二値分類問題
'objective': 'binary',
# AUCの最大化を目指す
'metric': 'auc',
# 正答率の最大化を目指す
#'metric': 'binary_error',
# Fatal の場合出力
'verbosity': -1,
# 学習率
'learning_rate': 0.02,
'deterministic': True, # 再現性確保用のパラメータ
'force_row_wise': True # 再現性確保用のパラメータ
}
booster = lgb.LightGBMTuner(
params = params,
train_set = lgb_train,
valid_sets = lgb_valid,
verbose_eval = 50, # 50イテレーション毎に学習結果出力
num_boost_round = 10000, # 最大イテレーション回数指定
early_stopping_rounds = 100,
categorical_feature = columns_cat, # categorical_featureを設定
optuna_seed = 31, # 再現性確保用のパラメータ
)
# 上記のパラメータでモデルを学習する
booster.run()
↓↓↓ココまでの成果↓↓↓
乱数Seedを指定した試行の自動化
データ分析も数値実験の扱いなので、筆者的は手放しで実行させました。
結果をファイル出力しておき、後にハイパーパラメータや精度指標を確認できるようにしました。
予測モデルを丸ごとpickleで保管・再実行することも考えたのですが、探索範囲が広いとDisk領域もそこそこ必要だし、乱数Seedさえ手元に残せば再現可能なので、筆者はこの方式としています。
forループを抜けてから一括でCSV出力としていますが、慎重に進めたい場合はforループの中で追記モードのCSV出力とするのが良いかと。
# 乱数Seedを指定した試行
# LightGBMTunerのインスタンス化、学習(≒探索)、最適モデル(およびパラメータ)取得などの流れは前節に同じ
# 更にモデル精度指標算出を追加
def autotuning(seed=31):
booster = lgb.LightGBMTuner(
params = params,
train_set = lgb_train,
valid_sets = lgb_valid,
verbose_eval = 50, # 50イテレーション毎に学習結果出力
num_boost_round = 10000, # 最大イテレーション回数指定
early_stopping_rounds = 100,
categorical_feature = columns_cat, # categorical_featureを設定
optuna_seed = seed, # 再現性確保用のパラメータ
)
# 上記のパラメータでモデルを学習する
booster.run()
# 最適なパラメータの表示
#print("Best params:", booster.best_params)
# 最適なモデル(Boosterオブジェクト)を取得する
model = booster.get_best_booster()
# 最適なパラメータの表示
best_params = model.params
print("Best params:", best_params)
# テストデータを予測する
y_pred = model.predict(X_test, num_iteration=model.best_iteration)
# AUCを計算
fpr, tpr, thresholds = roc_curve(np.asarray(y_test), y_pred)
#print("AUC", auc(fpr, tpr))
metrics = {
"seed": seed,
"auc": auc(fpr, tpr),
"accuracy": accuracy_score(np.asarray(y_test), np.round(y_pred)),
"precision": precision_score(np.asarray(y_test), np.round(y_pred)),
}
# Tupleとして戻り値に
return (best_params, metrics)
# 乱数シードを指定(from 素数を取得するワンライナー)
seeds = (lambda n:[x for x in range(2, n) if not 0 in map(lambda z: x%z, range(2, x))])(200)
# 各試行での情報を保存
listParams = []
listMetrics = []
# 乱数Seedを指定した試行を自動化
for s in seeds:
print("### seed:%3s ###" %s)
tunedparam, metrics = autotuning(s)
tunedparam["seed"] = s
listParams.append(tunedparam)
listMetrics.append(metrics)
# DataFrameからCSVファイルを生成
# encoding="SJIS"だとJupyterLab(CSVTable)上で表示不可なことに注意
df_params = pd.DataFrame(listParams)
df_params.to_csv('lightgbm_label_params.csv', index=False, header=True, encoding='UTF-8', sep=',', quotechar='"', quoting=csv.QUOTE_ALL)
df_metrics = pd.DataFrame(listMetrics)
df_metrics.to_csv('lightgbm_label_metrics.csv', index=False, header=True, encoding='UTF-8', sep=',', quotechar='"', quoting=csv.QUOTE_ALL)
↓↓↓ココまでの成果↓↓↓
実行結果例
200未満の素数を乱数Seedに指定して実行すると、こんな感じ。
Label Encodingの時は、乱数Seed:=127で最も良い結果となりました。
一方、One-hot Encodingの時は、乱数Seed:=67で最も良い結果となりました。
Label Encoding(w/Categorical_Feature)の方かなと予想していましたが、One-hot Encodingに軍配が上がりました。
以下、CSV出力(モデルの精度指標のみ)の結果です。
Label Encoding
seed | auc | accuracy | precision |
2 | 0.8938513791435835 | 0.85513011768606 | 0.6797814207650273 |
3 | 0.8944455631108564 | 0.8549643626719708 | 0.6798245614035088 |
5 | 0.8943227679593286 | 0.8567876678269518 | 0.6885428253615128 |
7 | 0.8948486272242101 | 0.8554616277142384 | 0.6853303471444568 |
11 | 0.8945356587826498 | 0.8556273827283275 | 0.6836283185840708 |
13 | 0.8944940099716203 | 0.8556273827283275 | 0.6828193832599119 |
17 | 0.8941711456348777 | 0.85513011768606 | 0.6829810901001112 |
19 | 0.8946069952792277 | 0.856124647770595 | 0.6871508379888268 |
23 | 0.8942950594528173 | 0.8572849328692193 | 0.6887417218543046 |
29 | 0.8941614218422198 | 0.8554616277142384 | 0.6849162011173184 |
31 | 0.8941373274887316 | 0.8554616277142384 | 0.6861642294713161 |
37 | 0.8946421902455731 | 0.8556273827283275 | 0.6840354767184036 |
41 | 0.8940544601229845 | 0.8556273827283275 | 0.6840354767184036 |
43 | 0.893948961275211 | 0.8556273827283275 | 0.6816192560175055 |
47 | 0.8944086471192618 | 0.8554616277142384 | 0.6828729281767956 |
53 | 0.8946586260367027 | 0.8549643626719708 | 0.6826280623608018 |
59 | 0.8944940099716203 | 0.8557931377424167 | 0.6835722160970231 |
61 | 0.8938493999645468 | 0.8557931377424167 | 0.6831683168316832 |
67 | 0.8945473617543442 | 0.8559588927565058 | 0.6847345132743363 |
71 | 0.8946022624597925 | 0.8554616277142384 | 0.6824696802646086 |
73 | 0.8945383263717863 | 0.85513011768606 | 0.680968096809681 |
79 | 0.8944894492547101 | 0.8556273827283275 | 0.6836283185840708 |
83 | 0.8943404945193951 | 0.8538040775733466 | 0.6797312430011199 |
89 | 0.8941464489225521 | 0.8554616277142384 | 0.6828729281767956 |
97 | 0.8938825297005931 | 0.8543013426156141 | 0.6784140969162996 |
101 | 0.8944909121261718 | 0.8559588927565058 | 0.6835164835164835 |
103 | 0.8945502014460055 | 0.8554616277142384 | 0.6845039018952063 |
107 | 0.8944299878323516 | 0.8546328526437925 | 0.6827354260089686 |
109 | 0.8942384377221198 | 0.8547986076578816 | 0.678298800436205 |
113 | 0.8944550287497268 | 0.8543013426156141 | 0.6820224719101123 |
127 | 0.894999819292349 | 0.8554616277142384 | 0.683684794672586 |
131 | 0.8943819712279 | 0.8562904027846843 | 0.6850220264317181 |
137 | 0.8943342988284981 | 0.8562904027846843 | 0.6854304635761589 |
139 | 0.894206340601223 | 0.8547986076578816 | 0.6826815642458101 |
149 | 0.8943289636502257 | 0.856124647770595 | 0.683461117196057 |
151 | 0.894111856315044 | 0.8544670976297033 | 0.6823793490460157 |
157 | 0.8944799836158397 | 0.8562904027846843 | 0.6846153846153846 |
163 | 0.8942234648024521 | 0.856124647770595 | 0.6838638858397366 |
167 | 0.8941788041972364 | 0.8564561577987734 | 0.6870144284128746 |
173 | 0.8941963586547778 | 0.8547986076578816 | 0.681465038845727 |
179 | 0.8942485917710898 | 0.8554616277142384 | 0.6828729281767956 |
181 | 0.8941140075966054 | 0.8556273827283275 | 0.6832229580573952 |
191 | 0.8944329996265374 | 0.8559588927565058 | 0.683920704845815 |
193 | 0.8943032343227507 | 0.8566219128128626 | 0.6840958605664488 |
197 | 0.8944410023939461 | 0.8544670976297033 | 0.6823793490460157 |
199 | 0.8940823407320209 | 0.85513011768606 | 0.6829810901001112 |
One-hot Encoding
seed | auc | accuracy | precision |
2 | 0.8944794673082648 | 0.8534725675451682 | 0.6786114221724524 |
3 | 0.8941294968238479 | 0.8554616277142384 | 0.6853303471444568 |
5 | 0.8943444528774681 | 0.8557931377424167 | 0.6831683168316832 |
7 | 0.894542973139959 | 0.8556273827283275 | 0.6816192560175055 |
11 | 0.8951063507552719 | 0.8562904027846843 | 0.6842105263157895 |
13 | 0.8944831675125506 | 0.8559588927565058 | 0.6815217391304348 |
17 | 0.894911014389492 | 0.8549643626719708 | 0.6786492374727668 |
19 | 0.8946778154682308 | 0.8546328526437925 | 0.6831460674157304 |
23 | 0.8945323888346766 | 0.8546328526437925 | 0.6803097345132744 |
29 | 0.8944440141881322 | 0.8552958727001492 | 0.6797385620915033 |
31 | 0.8946890881836128 | 0.8546328526437925 | 0.6803097345132744 |
37 | 0.8942121060358075 | 0.8562904027846843 | 0.6854304635761589 |
41 | 0.8944607941843113 | 0.8552958727001492 | 0.6829268292682927 |
43 | 0.8940910319095292 | 0.8536383225592574 | 0.6769911504424779 |
47 | 0.8942133968047445 | 0.8549643626719708 | 0.6846846846846847 |
53 | 0.8945319585783642 | 0.8538040775733466 | 0.6773480662983425 |
59 | 0.8943666541031823 | 0.8566219128128626 | 0.6894618834080718 |
61 | 0.8942538408980999 | 0.8552958727001492 | 0.6817180616740088 |
67 | 0.8951266588532122 | 0.8556273827283275 | 0.6832229580573952 |
71 | 0.8937288421458431 | 0.8544670976297033 | 0.6795580110497238 |
73 | 0.8944449607520192 | 0.8547986076578816 | 0.6822742474916388 |
79 | 0.8939726253723869 | 0.8552958727001492 | 0.6825221238938053 |
83 | 0.8946227426602575 | 0.8543013426156141 | 0.675704989154013 |
89 | 0.8947935544162368 | 0.8569534228410409 | 0.6914414414414415 |
97 | 0.8942900684795947 | 0.8543013426156141 | 0.6796008869179601 |
101 | 0.8941228708766386 | 0.8556273827283275 | 0.6836283185840708 |
103 | 0.8943894576877338 | 0.8554616277142384 | 0.6845039018952063 |
107 | 0.8943374827252092 | 0.8534725675451682 | 0.6766334440753046 |
109 | 0.8947576710397919 | 0.856124647770595 | 0.6863181312569522 |
113 | 0.894577565747467 | 0.8549643626719708 | 0.6778741865509761 |
127 | 0.8937288421458431 | 0.8544670976297033 | 0.6795580110497238 |
131 | 0.8945638835967362 | 0.85513011768606 | 0.6789989118607181 |
137 | 0.8943548650802254 | 0.8533068125310791 | 0.6774553571428571 |
139 | 0.8944597615691621 | 0.8543013426156141 | 0.6780219780219781 |
149 | 0.8940351846401939 | 0.8546328526437925 | 0.6791208791208792 |
151 | 0.8939071403616562 | 0.8557931377424167 | 0.6823658269441402 |
157 | 0.8947278973029814 | 0.8546328526437925 | 0.6799116997792495 |
163 | 0.8943990093778665 | 0.8549643626719708 | 0.6818181818181818 |
167 | 0.8943369664176344 | 0.8549643626719708 | 0.6822222222222222 |
173 | 0.8948788312173328 | 0.85513011768606 | 0.6817679558011049 |
179 | 0.8938926837495634 | 0.8557931377424167 | 0.6839779005524862 |
181 | 0.894712063870689 | 0.8539698325874358 | 0.676923076923077 |
191 | 0.8942296604933492 | 0.8544670976297033 | 0.6803551609322974 |
193 | 0.8946532908584303 | 0.8571191778551301 | 0.6871569703622393 |
197 | 0.8946267870695931 | 0.8544670976297033 | 0.679162072767365 |
199 | 0.8947976848768349 | 0.8543013426156141 | 0.6792035398230089 |
Label Encodingの方式で最も良い精度のモデルを獲得できる乱数Seedで、改めて試行し精度や特徴量重要度を確認しました。
乱数シードの違いでAUCは最大約1%の差が出ましたが、特徴量重要度の上位ランクインした特徴量に差はありませんでした。
↓↓↓ココまでの成果↓↓↓
Notebook(Pythonコード)
GitHubで公開しています。このタイミングですが、参考サイトの方々に感謝申し上げます。
まとめ
Optunaを利用してLightGBMのハイパーパラメータ・チューニングをする方法を紹介しました。
AutoML、例えばTPOT7やPyCaret8を利用したら、モデル構築は更に省エネ化できそうだなと実感しました。。。
-
カテゴリ値を含むデータでのLightGBMを使った分析 | https://qiita.com/Blaster36/items/d0a99f9d4f6b12a2dd83 ↩
-
カテゴリ値を含むデータでのLightGBMを使った分析(2本目:One-hot Encoding編) | https://qiita.com/Blaster36/items/8461ec0dca662750eb71 ↩
-
Adult Census Incomeh | https://www.kaggle.com/datasets/uciml/adult-census-income ↩
-
PEP 8 – Style Guide for Python Code | https://peps.python.org/pep-0008/ ↩
-
LightGBMでOptunaを使用するときの再現性確保について | https://book-read-yoshi.hatenablog.com/entry/2021/03/22/lightgbm_optuna_deterministic ↩
-
optuna.integration.lightgbm.LightGBMTuner | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.integration.lightgbm.LightGBMTuner.html ↩
-
TPOT | http://automl.info/tpot/ ↩
-
PyCaret | https://pycaret.org/ ↩