はじめに
LightGBMで初めてモデル作成するときに、APIが2種類あって混乱することがあると思います。本記事で実装の差異に注目してモデルを作成しようと思います。
こんな方におすすめ
・データ分析初心者
・とりあえずLightGBM使ってコンペに参加しようと考えている人
1. データ準備
Breast Cancer Wisconsin (Diagnostic) Data Setという乳がんの診断結果をもとに、患者が良性か悪性かを分類する二値分類のデータセットを用います。また、評価指標にはAccuracyを用います。
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import lightgbm as lgb
# データセットをロードする
data = load_breast_cancer()
# シード
seed = 42
# 説明変数をXに、目的変数をyに格納する
X = data.data
y = data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)
ちなみにseedに42がよく使われる理由は下の記事を参考に
2. Training APIで実装
大まかには、Datasetオブジェクトの作成 -> パラメータ辞書の作成 -> train()で学習といった流れになります。
Datasetオブジェクトの作成とtrainで学習するところがScikit-learn APIと異なる点です。
# Datasetオブジェクトの作成 (異なるポイント①)
d_train = lgb.Dataset(data=X_train,label=y_train)
d_test = lgb.Dataset(data=X_test,label=y_test)
# パラメータ設定
params = {
'objective':'binary', # 二値分類タスク(回帰であればregression)
'metric':'binary_logloss', # objectiveを定義しているので自動入力されるが、明示しておく
'n_estimators':10000, # early_stopping使用時は大きな値
'verbosity': -1, # コマンドライン出力しない設定
'random_state':seed
}
# モデルの構築・学習 (異なるポイント②)
gbm = lgb.train(params, d_train, valid_sets=[d_test], # early_stoppingの評価用データ
callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=True)] # early_stopping用コールバック関数
)
# テストデータで予測
y_pred = gbm.predict(X_test)
y_pred = y_pred.round(0) # 丸め込み
# テストデータの評価
accuracy = accuracy_score(y_pred,y_test)
print(accuracy)
# 0.9736842105263158
※ early_stoppingを指定することで、過学習の防止や時間の短縮に繋がります。
設定できるパラメータについて知りたい方は、以下の記事やドキュメントを参考に
3. Scikit-learn APIで実装
大まかにはパラメータ辞書の作成 -> LGBMモデルのインスタンス作成 -> fit()で学習といった流れになります。
LGBMモデルのインスタンス作成とfitで学習するところがTraining APIと異なる点です
# パラメータ設定
params = {
'objective':'binary',
'metric':'binary_logloss',
'n_estimators':10000,
'verbosity': -1,
'random_state':seed
}
# インスタンスの作成 (異なるポイント①)
gbm = lgb.LGBMClassifier(**params) # 回帰であればlgb.LGBMRegressor()
# モデルの学習 (異なるポイント②)
gbm.fit(X_train, y_train, eval_metric='bainry_logloss', eval_set=[(X_test, y_test)],
callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=True)]
)
# テストデータで予測
y_pred = gbm.predict(X_test)
y_pred = y_pred.round(0) # 丸め込み
# テストデータの評価
accuracy = accuracy_score(y_pred,y_test)
print(accuracy)
# 0.9736842105263158
4. まとめ
今回はLightGBMの分類モデルの作成方法を、APIに着目してシンプルにまとめてみました。今回はホールドアウトで評価していますが、クロスバリデーションを行う場合もTraining APIとScikit-learn APIで異なります。下の記事が参考になります。
最後までご覧いただきありがとうございました!