LoginSignup
1
0

【LightGBM】lgb.trainメソッドで学習したモデルのPermutation Importanceの計算方法

Posted at

概要

LightGBMで学習したモデルのPermutation Importanceを計測したい場合に
例えば

model = lgb.LGBMRegressor(objective="regression")
gbm = model.fit(**params)

のようにfitメソッドを使ってトレーニングを行えばPermutation Importanceを計算できるのですが

model = lgb.train(params, lgb_train)

のようにtrainメソッドを使った場合に、Permutation Importanceを計算すると以下のエラーが出る問題があります。

TypeError: estimator should be an estimator implementing 'fit' method,  was passed

本記事では、このエラーの対処方法を記載します。

起こる事象

サンプルとして、ボストンデータセットでLightGBMのモデルを作成します

import lightgbm as lgb
import numpy as np
from lightgbm import LGBMRegressor
from sklearn import datasets
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# データセットの読み込み
boston = datasets.load_boston()
X = boston.data
y = boston.target

# データセットの分割
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# LightGBMデータセットに変換
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

# パラメータ設定
params = {
    "boosting_type": "gbdt",
    "objective": "regression",
    "metric": "rmse",
}

# モデルの学習
model = lgb.train(
    params,
    lgb_train,
    num_boost_round=20,
    valid_sets=lgb_eval,
    early_stopping_rounds=5,
)

このモデルに対しPermutation Importanceを計算します

result = permutation_importance(
    model,
    X_test,
    y_test,
    n_repeats=10,
    n_jobs=-1,
    random_state=0,
)

この時、以下のエラーが発生します

TypeError: estimator should be an estimator implementing 'fit' method,  was passed

どうやら、fitメソッドを持つモデル(LGBMRegressorなど)にしかPermutation Importanceは対応していないため、モデルの形式を変換する必要がありそうです。

対処方法

以下の方法で計算できるようになります。

regressor_model = LGBMRegressor()
regressor_model._Booster = model
regressor_model._n_features = X_test.shape[1]
regressor_model.fitted_ = True

改めてPermutation Importanceを計算します

result = permutation_importance(
    regressor_model,
    X_test,
    y_test,
    n_repeats=10,
    n_jobs=-1,
    random_state=0,
)

今度はエラーなくうまく行きました。
例えばoptunaのLightGBM Tunerを使った場合はtrainメソッドでしか学習ができないと思いますので、こちらの方法を採用すれば良いかと思います。
ちなみに、今回はscikit-learnのpermutation_importanceを利用しましたが、eli5でも同様のエラーは発生するため、同様の対処が必要になります。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0