16
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

LightGBMのカスタムメトリック

Posted at

カスタムメトリックの作り方

LightGBMのScikit-learn APIの場合のカスタムメトリックとして、4クラス分類のときのF1スコアを作ってみます。
(y_true, y_pred)を引数に持つ関数を作ればよいです。次のような関数になります。

from sklearn.metrics import f1_score

def f1(y_true, y_pred):
    N_LABELS = 4  # ラベルの数
    y_pred_ = y_pred.reshape(N_LABELS, len(y_pred) // N_LABELS).argmax(axis=0)
    score = f1_score(y_true, y_pred_, average='macro')
    return 'f1', score, True

返り値は、メトリックの名前、スコアの値、値が大きいほうが良いか否か、です。3つ目の返り値はEarly stoppingなどに使用されます。
また、他クラス分類のように予測が複数列を持つ場合は、y_predが1列で渡されるため、reshapeする必要があるのですが、並び替えの向きが行方向をラベルにする必要があるので気を付ける必要があります。

カスタムメトリックはeval_metricに渡してあげれば使用できます。

params= {
    'objective': 'multiclass',
    'num_class':4, 
    'learning_rate': .01, 
    'max_depth': 6,
    'n_estimators': 1000,
    'colsample_bytree': .7,
    'importance_type': 'gain',
}

clf = lgb.LGBMClassifier(**params)
clf.fit(X_train, y_train, 
        eval_set=[(X_valid, y_valid)],  
        early_stopping_rounds=100, 
        eval_metric=f1,
        verbose=verbose)

## 出力
# Training until validation scores don't improve for 100 rounds
# [100]	valid_0's multi_logloss: 0.103872	valid_0's f1: 0.92157
# [200]	valid_0's multi_logloss: 0.0416318	valid_0's f1: 0.971213
# ...

カスタムメトリックはあくまで参考値で、objectiveのみvalidationに使いたい場合はparams'first_metric_only': Trueを追加します。

逆にvalidationをカスタムメトリックのみで行いたい場合は、params"metric" : "None"を追加すればOKです。

16
8
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?