4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

カテゴリ値を含むデータでのLightGBMを使った分析(3本目:ハイパーパラメータ・チューニング編)

Posted at

はじめに

過去の投稿12では、Kaggleのデータセット「Adult Census Income3」を使って二値分類問題にチャレンジしました。
その際、カテゴリ値の取扱い方法としては、Label EncodingOne-hot Encodingを採用した場合を紹介しました。

が、ハイパーパラメータ・チューニングには取り組んでいませんでした。

そこで本稿では、Optunaを利用したLightGBMのハイパーパラメータ・チューニングにチャレンジします(問題設定は過去の投稿に同じ)。

本稿で紹介すること

  • ハイパーパラメータ・チューニング
  • 乱数Seedを指定した試行の自動化

過去の投稿からの変更部分を中心に、説明します。
が、紹介するPythonコードはリファクタリング未済のため、PEP84には則っていません!

本稿で紹介しないこと

  • データの収集および整形
  • 特徴量設計
  • 予測モデルの作成
  • 予測精度の確認
  • 特徴量重要度の確認(可視化)

分析環境の各種情報

筆者は、Windows 11のホスト上でJupyterを起動しています。(コンテナは不使用)

  • Python 3.9.13
  • pip 23.0.1
  • pandas 1.5.3
  • numpy 1.24.2
  • scikit-learn 1.2.1
  • lightgbm 3.3.5
  • optuna 3.1.0
  • matplotlib 3.6.3

ハイパーパラメータ・チューニング

こちらのサイト5で公開されているPythonコードを参考にしました。
LightGBMTunerのインスタンス化、学習(≒探索)、最適モデル(およびパラメータ)取得などの流れは同じです。

ポイントは、optuna.integration.lightgbm.LightGBMTuner6を利用することです。

# LightGBMのハイパーパラメータ
params = {
    # 勾配ブースティング木(GBDT:Gradient Boosting Decision Tree)
    'boosting_type': 'gbdt',
    # 二値分類問題
    'objective': 'binary',
    # AUCの最大化を目指す
    'metric': 'auc',
    # 正答率の最大化を目指す
    #'metric': 'binary_error',
    # Fatal の場合出力
    'verbosity': -1,
    # 学習率
    'learning_rate': 0.02,
    'deterministic': True, # 再現性確保用のパラメータ
    'force_row_wise': True  # 再現性確保用のパラメータ
}

booster = lgb.LightGBMTuner(
    params = params,
    train_set = lgb_train,
    valid_sets = lgb_valid,
    verbose_eval = 50,  # 50イテレーション毎に学習結果出力
    num_boost_round = 10000,  # 最大イテレーション回数指定
    early_stopping_rounds = 100,
    categorical_feature = columns_cat, # categorical_featureを設定
    optuna_seed = 31, # 再現性確保用のパラメータ
)

# 上記のパラメータでモデルを学習する
booster.run()

↓↓↓ココまでの成果↓↓↓

乱数Seedを指定した試行の自動化

データ分析も数値実験の扱いなので、筆者的は手放しで実行させました。
結果をファイル出力しておき、後にハイパーパラメータや精度指標を確認できるようにしました。

予測モデルを丸ごとpickleで保管・再実行することも考えたのですが、探索範囲が広いとDisk領域もそこそこ必要だし、乱数Seedさえ手元に残せば再現可能なので、筆者はこの方式としています。
forループを抜けてから一括でCSV出力としていますが、慎重に進めたい場合はforループの中で追記モードのCSV出力とするのが良いかと。

# 乱数Seedを指定した試行
# LightGBMTunerのインスタンス化、学習(≒探索)、最適モデル(およびパラメータ)取得などの流れは前節に同じ
# 更にモデル精度指標算出を追加
def autotuning(seed=31):
    booster = lgb.LightGBMTuner(
        params = params,
        train_set = lgb_train,
        valid_sets = lgb_valid,
        verbose_eval = 50,  # 50イテレーション毎に学習結果出力
        num_boost_round = 10000,  # 最大イテレーション回数指定
        early_stopping_rounds = 100,
        categorical_feature = columns_cat, # categorical_featureを設定
        optuna_seed = seed, # 再現性確保用のパラメータ
    )

    # 上記のパラメータでモデルを学習する
    booster.run()

    # 最適なパラメータの表示
    #print("Best params:", booster.best_params)

    # 最適なモデル(Boosterオブジェクト)を取得する
    model = booster.get_best_booster()

    # 最適なパラメータの表示
    best_params = model.params
    print("Best params:", best_params)

    # テストデータを予測する
    y_pred = model.predict(X_test, num_iteration=model.best_iteration)

    # AUCを計算
    fpr, tpr, thresholds = roc_curve(np.asarray(y_test), y_pred)
    #print("AUC", auc(fpr, tpr))

    metrics = {
        "seed": seed,
        "auc": auc(fpr, tpr),
        "accuracy": accuracy_score(np.asarray(y_test), np.round(y_pred)),
        "precision": precision_score(np.asarray(y_test), np.round(y_pred)),
    }

    # Tupleとして戻り値に
    return (best_params, metrics)

# 乱数シードを指定(from 素数を取得するワンライナー)
seeds = (lambda n:[x for x in range(2, n) if not 0 in map(lambda z: x%z, range(2, x))])(200)

# 各試行での情報を保存
listParams = []
listMetrics = []

# 乱数Seedを指定した試行を自動化
for s in seeds:
    print("### seed:%3s ###" %s)
    tunedparam, metrics = autotuning(s)
    tunedparam["seed"] = s
    listParams.append(tunedparam)
    listMetrics.append(metrics)

# DataFrameからCSVファイルを生成
# encoding="SJIS"だとJupyterLab(CSVTable)上で表示不可なことに注意
df_params = pd.DataFrame(listParams)
df_params.to_csv('lightgbm_label_params.csv', index=False, header=True, encoding='UTF-8', sep=',', quotechar='"', quoting=csv.QUOTE_ALL)
df_metrics = pd.DataFrame(listMetrics)
df_metrics.to_csv('lightgbm_label_metrics.csv', index=False, header=True, encoding='UTF-8', sep=',', quotechar='"', quoting=csv.QUOTE_ALL)

↓↓↓ココまでの成果↓↓↓

実行結果例

200未満の素数を乱数Seedに指定して実行すると、こんな感じ。
Label Encodingの時は、乱数Seed:=127で最も良い結果となりました。
一方、One-hot Encodingの時は、乱数Seed:=67で最も良い結果となりました。

Label Encoding(w/Categorical_Feature)の方かなと予想していましたが、One-hot Encodingに軍配が上がりました。
以下、CSV出力(モデルの精度指標のみ)の結果です。

Label Encoding
seed auc accuracy precision
2 0.8938513791435835 0.85513011768606 0.6797814207650273
3 0.8944455631108564 0.8549643626719708 0.6798245614035088
5 0.8943227679593286 0.8567876678269518 0.6885428253615128
7 0.8948486272242101 0.8554616277142384 0.6853303471444568
11 0.8945356587826498 0.8556273827283275 0.6836283185840708
13 0.8944940099716203 0.8556273827283275 0.6828193832599119
17 0.8941711456348777 0.85513011768606 0.6829810901001112
19 0.8946069952792277 0.856124647770595 0.6871508379888268
23 0.8942950594528173 0.8572849328692193 0.6887417218543046
29 0.8941614218422198 0.8554616277142384 0.6849162011173184
31 0.8941373274887316 0.8554616277142384 0.6861642294713161
37 0.8946421902455731 0.8556273827283275 0.6840354767184036
41 0.8940544601229845 0.8556273827283275 0.6840354767184036
43 0.893948961275211 0.8556273827283275 0.6816192560175055
47 0.8944086471192618 0.8554616277142384 0.6828729281767956
53 0.8946586260367027 0.8549643626719708 0.6826280623608018
59 0.8944940099716203 0.8557931377424167 0.6835722160970231
61 0.8938493999645468 0.8557931377424167 0.6831683168316832
67 0.8945473617543442 0.8559588927565058 0.6847345132743363
71 0.8946022624597925 0.8554616277142384 0.6824696802646086
73 0.8945383263717863 0.85513011768606 0.680968096809681
79 0.8944894492547101 0.8556273827283275 0.6836283185840708
83 0.8943404945193951 0.8538040775733466 0.6797312430011199
89 0.8941464489225521 0.8554616277142384 0.6828729281767956
97 0.8938825297005931 0.8543013426156141 0.6784140969162996
101 0.8944909121261718 0.8559588927565058 0.6835164835164835
103 0.8945502014460055 0.8554616277142384 0.6845039018952063
107 0.8944299878323516 0.8546328526437925 0.6827354260089686
109 0.8942384377221198 0.8547986076578816 0.678298800436205
113 0.8944550287497268 0.8543013426156141 0.6820224719101123
127 0.894999819292349 0.8554616277142384 0.683684794672586
131 0.8943819712279 0.8562904027846843 0.6850220264317181
137 0.8943342988284981 0.8562904027846843 0.6854304635761589
139 0.894206340601223 0.8547986076578816 0.6826815642458101
149 0.8943289636502257 0.856124647770595 0.683461117196057
151 0.894111856315044 0.8544670976297033 0.6823793490460157
157 0.8944799836158397 0.8562904027846843 0.6846153846153846
163 0.8942234648024521 0.856124647770595 0.6838638858397366
167 0.8941788041972364 0.8564561577987734 0.6870144284128746
173 0.8941963586547778 0.8547986076578816 0.681465038845727
179 0.8942485917710898 0.8554616277142384 0.6828729281767956
181 0.8941140075966054 0.8556273827283275 0.6832229580573952
191 0.8944329996265374 0.8559588927565058 0.683920704845815
193 0.8943032343227507 0.8566219128128626 0.6840958605664488
197 0.8944410023939461 0.8544670976297033 0.6823793490460157
199 0.8940823407320209 0.85513011768606 0.6829810901001112
One-hot Encoding
seed auc accuracy precision
2 0.8944794673082648 0.8534725675451682 0.6786114221724524
3 0.8941294968238479 0.8554616277142384 0.6853303471444568
5 0.8943444528774681 0.8557931377424167 0.6831683168316832
7 0.894542973139959 0.8556273827283275 0.6816192560175055
11 0.8951063507552719 0.8562904027846843 0.6842105263157895
13 0.8944831675125506 0.8559588927565058 0.6815217391304348
17 0.894911014389492 0.8549643626719708 0.6786492374727668
19 0.8946778154682308 0.8546328526437925 0.6831460674157304
23 0.8945323888346766 0.8546328526437925 0.6803097345132744
29 0.8944440141881322 0.8552958727001492 0.6797385620915033
31 0.8946890881836128 0.8546328526437925 0.6803097345132744
37 0.8942121060358075 0.8562904027846843 0.6854304635761589
41 0.8944607941843113 0.8552958727001492 0.6829268292682927
43 0.8940910319095292 0.8536383225592574 0.6769911504424779
47 0.8942133968047445 0.8549643626719708 0.6846846846846847
53 0.8945319585783642 0.8538040775733466 0.6773480662983425
59 0.8943666541031823 0.8566219128128626 0.6894618834080718
61 0.8942538408980999 0.8552958727001492 0.6817180616740088
67 0.8951266588532122 0.8556273827283275 0.6832229580573952
71 0.8937288421458431 0.8544670976297033 0.6795580110497238
73 0.8944449607520192 0.8547986076578816 0.6822742474916388
79 0.8939726253723869 0.8552958727001492 0.6825221238938053
83 0.8946227426602575 0.8543013426156141 0.675704989154013
89 0.8947935544162368 0.8569534228410409 0.6914414414414415
97 0.8942900684795947 0.8543013426156141 0.6796008869179601
101 0.8941228708766386 0.8556273827283275 0.6836283185840708
103 0.8943894576877338 0.8554616277142384 0.6845039018952063
107 0.8943374827252092 0.8534725675451682 0.6766334440753046
109 0.8947576710397919 0.856124647770595 0.6863181312569522
113 0.894577565747467 0.8549643626719708 0.6778741865509761
127 0.8937288421458431 0.8544670976297033 0.6795580110497238
131 0.8945638835967362 0.85513011768606 0.6789989118607181
137 0.8943548650802254 0.8533068125310791 0.6774553571428571
139 0.8944597615691621 0.8543013426156141 0.6780219780219781
149 0.8940351846401939 0.8546328526437925 0.6791208791208792
151 0.8939071403616562 0.8557931377424167 0.6823658269441402
157 0.8947278973029814 0.8546328526437925 0.6799116997792495
163 0.8943990093778665 0.8549643626719708 0.6818181818181818
167 0.8943369664176344 0.8549643626719708 0.6822222222222222
173 0.8948788312173328 0.85513011768606 0.6817679558011049
179 0.8938926837495634 0.8557931377424167 0.6839779005524862
181 0.894712063870689 0.8539698325874358 0.676923076923077
191 0.8942296604933492 0.8544670976297033 0.6803551609322974
193 0.8946532908584303 0.8571191778551301 0.6871569703622393
197 0.8946267870695931 0.8544670976297033 0.679162072767365
199 0.8947976848768349 0.8543013426156141 0.6792035398230089

Label Encodingの方式で最も良い精度のモデルを獲得できる乱数Seedで、改めて試行し精度や特徴量重要度を確認しました。
乱数シードの違いでAUCは最大約1%の差が出ましたが、特徴量重要度の上位ランクインした特徴量に差はありませんでした。

image.png

image.png

↓↓↓ココまでの成果↓↓↓

Notebook(Pythonコード)

GitHubで公開しています。このタイミングですが、参考サイトの方々に感謝申し上げます。

まとめ

Optunaを利用してLightGBMのハイパーパラメータ・チューニングをする方法を紹介しました。
AutoML、例えばTPOT7やPyCaret8を利用したら、モデル構築は更に省エネ化できそうだなと実感しました。。。

  1. カテゴリ値を含むデータでのLightGBMを使った分析 | https://qiita.com/Blaster36/items/d0a99f9d4f6b12a2dd83

  2. カテゴリ値を含むデータでのLightGBMを使った分析(2本目:One-hot Encoding編) | https://qiita.com/Blaster36/items/8461ec0dca662750eb71

  3. Adult Census Incomeh | https://www.kaggle.com/datasets/uciml/adult-census-income

  4. PEP 8 – Style Guide for Python Code | https://peps.python.org/pep-0008/

  5. LightGBMでOptunaを使用するときの再現性確保について | https://book-read-yoshi.hatenablog.com/entry/2021/03/22/lightgbm_optuna_deterministic

  6. optuna.integration.lightgbm.LightGBMTuner | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.integration.lightgbm.LightGBMTuner.html

  7. TPOT | http://automl.info/tpot/

  8. PyCaret | https://pycaret.org/

4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?