17
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LightGBM:Training APIとScikit-learn APIの違い

Last updated at Posted at 2023-03-27

はじめに

LightGBMで初めてモデル作成するときに、APIが2種類あって混乱することがあると思います。本記事で実装の差異に注目してモデルを作成しようと思います。

こんな方におすすめ
・データ分析初心者
・とりあえずLightGBM使ってコンペに参加しようと考えている人

1. データ準備

Breast Cancer Wisconsin (Diagnostic) Data Setという乳がんの診断結果をもとに、患者が良性か悪性かを分類する二値分類のデータセットを用います。また、評価指標にはAccuracyを用います。

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import lightgbm as lgb

# データセットをロードする
data = load_breast_cancer()

# シード
seed = 42

# 説明変数をXに、目的変数をyに格納する
X = data.data
y = data.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)

ちなみにseedに42がよく使われる理由は下の記事を参考に

2. Training APIで実装

大まかには、Datasetオブジェクトの作成 -> パラメータ辞書の作成 -> train()で学習といった流れになります。
Datasetオブジェクトの作成trainで学習するところがScikit-learn APIと異なる点です。

# Datasetオブジェクトの作成 (異なるポイント①)
d_train = lgb.Dataset(data=X_train,label=y_train)
d_test = lgb.Dataset(data=X_test,label=y_test)

# パラメータ設定
params = {
    'objective':'binary', # 二値分類タスク(回帰であればregression)
    'metric':'binary_logloss', # objectiveを定義しているので自動入力されるが、明示しておく
    'n_estimators':10000, # early_stopping使用時は大きな値
    'verbosity': -1, # コマンドライン出力しない設定
    'random_state':seed
}

# モデルの構築・学習 (異なるポイント②)
gbm = lgb.train(params, d_train, valid_sets=[d_test], # early_stoppingの評価用データ
                callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=True)] # early_stopping用コールバック関数
               )

# テストデータで予測
y_pred = gbm.predict(X_test)
y_pred = y_pred.round(0) # 丸め込み

# テストデータの評価
accuracy = accuracy_score(y_pred,y_test)
print(accuracy)
# 0.9736842105263158

※ early_stoppingを指定することで、過学習の防止や時間の短縮に繋がります。
設定できるパラメータについて知りたい方は、以下の記事やドキュメントを参考に

3. Scikit-learn APIで実装

大まかにはパラメータ辞書の作成 -> LGBMモデルのインスタンス作成 -> fit()で学習といった流れになります。
LGBMモデルのインスタンス作成fitで学習するところがTraining APIと異なる点です

# パラメータ設定
params = {
    'objective':'binary',
    'metric':'binary_logloss',
    'n_estimators':10000,
    'verbosity': -1,
    'random_state':seed
}

# インスタンスの作成 (異なるポイント①)
gbm = lgb.LGBMClassifier(**params)  # 回帰であればlgb.LGBMRegressor()

# モデルの学習 (異なるポイント②)
gbm.fit(X_train, y_train, eval_metric='bainry_logloss', eval_set=[(X_test, y_test)],
        callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=True)]
       )

# テストデータで予測
y_pred = gbm.predict(X_test)
y_pred = y_pred.round(0) # 丸め込み

# テストデータの評価
accuracy = accuracy_score(y_pred,y_test)
print(accuracy)
# 0.9736842105263158

4. まとめ

今回はLightGBMの分類モデルの作成方法を、APIに着目してシンプルにまとめてみました。今回はホールドアウトで評価していますが、クロスバリデーションを行う場合もTraining APIとScikit-learn APIで異なります。下の記事が参考になります。

最後までご覧いただきありがとうございました!

17
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?