LoginSignup
1
2

More than 1 year has passed since last update.

Imbalanced-learn の不均衡データセットをLightGBMで学習して評価する

Posted at

不均衡データセットのベースラインモデル評価

Imbalanced-learnに実装されている27の不均衡なデータセットを使用する。ここでは、以下のことを行う。

  • 27のデータセットの読み込み
  • LightGBMによるベースラインモデルの評価

実行環境
このnotebookはGoogle Colaboratoryで実行されている。

!python --version
Python 3.7.13

ライブラリ
使用しているライブラリは以下の4つ。

import imblearn
import sklearn
import pandas as pd
import lightgbm as lgb

print(f"Imbalanced-learn version {imblearn.__version__}")
print(f"Scikit-learn version     {sklearn.__version__}")
print(f"Pandas version           {pd.__version__}")
print(f"LightGBM version         {lgb.__version__}")
Imbalanced-learn version 0.8.1
Scikit-learn version     1.0.2
Pandas version           1.3.5
LightGBM version         2.2.3

データセットの読み込み

Imbalanced-learnのベンチマークデータセットを使用する。27のデータセットはimblearn.datasets.fetch_datasets()で読み込む。データセットはImbalanced dataset for benchmarking: zenodoからダウンロードされる。

from imblearn.datasets import fetch_datasets

datasets = fetch_datasets()
datasets
OrderedDict([('ecoli',
              {'DESCR': 'ecoli',
               'data': array([[0.49, 0.29, 0.48, ..., 0.56, 0.24, 0.35],
                      [0.07, 0.4 , 0.48, ..., 0.54, 0.35, 0.44],
                      [0.56, 0.4 , 0.48, ..., 0.49, 0.37, 0.46],
                      ...,
                      [0.61, 0.6 , 0.48, ..., 0.44, 0.39, 0.38],
                      [0.59, 0.61, 0.48, ..., 0.42, 0.42, 0.37],
                      [0.74, 0.74, 0.48, ..., 0.31, 0.53, 0.52]]),
               'target': array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
                       1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
                       1,  1,  1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])}),

             ....

             ('protein_homo',
              {'DESCR': 'protein_homo',
               'data': array([[ 52.  ,  32.69,   0.3 , ...,  -0.35,   0.26,   0.76],
                      [ 58.  ,  33.33,   0.  , ...,   1.16,   0.39,   0.73],
                      [ 77.  ,  27.27,  -0.91, ...,  -0.76,   0.26,   0.24],
                      ...,
                      [100.  ,  71.76,  41.92, ...,   3.41,   0.44,   0.78],
                      [ 85.65,  26.46,   1.85, ...,   2.88,   0.54,   0.77],
                      [ 87.5 ,  29.33,   5.84, ...,  -0.58,   0.16,   0.23]]),
               'target': array([-1, -1, -1, ...,  1, -1,  1])}),
             ('abalone_19',
              {'DESCR': 'abalone_19',
               'data': array([[0.    , 0.    , 1.    , ..., 0.2245, 0.101 , 0.15  ],
                      [0.    , 0.    , 1.    , ..., 0.0995, 0.0485, 0.07  ],
                      [1.    , 0.    , 0.    , ..., 0.2565, 0.1415, 0.21  ],
                      ...,
                      [0.    , 0.    , 1.    , ..., 0.5255, 0.2875, 0.308 ],
                      [1.    , 0.    , 0.    , ..., 0.531 , 0.261 , 0.296 ],
                      [0.    , 0.    , 1.    , ..., 0.9455, 0.3765, 0.495 ]]),
               'target': array([-1, -1, -1, ..., -1, -1, -1])})])

データセットの情報は下記の表にまとめてある。Nameはデータセットの名前であり、個々のデータセットにアクセスするときに使用する。Repositoryはデータセットの配布場所である。中には多クラス分類問題のデータセットが含まれているが、二クラス分類問題となるように加工されている。Targetは少数クラスのカテゴリを指している。Ratioは多数クラスと少数クラス間のサンプル数の比である。これは、不均衡の度合いを示す。#Sはサンプル数であり、#Fは特徴数である。カテゴリ変数がone-hot表現に変換されていることに注意する。

ID Name Repository & Target Ratio #S #F
1 ecoli UCI, target: imU 8.6:1 336 7
2 optical_digits UCI, target: 8 9.1:1 5,620 64
3 satimage UCI, target: 4 9.3:1 6,435 36
4 pen_digits UCI, target: 5 9.4:1 10,992 16
5 abalone UCI, target: 7 9.7:1 4,177 10
6 sick_euthyroid UCI, target: sick euthyroid 9.8:1 3,163 42
7 spectrometer UCI, target: >=44 11:1 531 93
8 car_eval_34 UCI, target: good, v good 12:1 1,728 21
9 isolet UCI, target: A, B 12:1 7,797 617
10 us_crime UCI, target: >0.65 12:1 1,994 100
11 yeast_ml8 LIBSVM, target: 8 13:1 2,417 103
12 scene LIBSVM, target: >one label 13:1 2,407 294
13 libras_move UCI, target: 1 14:1 360 90
14 thyroid_sick UCI, target: sick 15:1 3,772 52
15 coil_2000 KDD, CoIL, target: minority 16:1 9,822 85
16 arrhythmia UCI, target: 06 17:1 452 278
17 solar_flare_m0 UCI, target: M->0 19:1 1,389 32
18 oil UCI, target: minority 22:1 937 49
19 car_eval_4 UCI, target: vgood 26:1 1,728 21
20 wine_quality UCI, wine, target: <=4 26:1 4,898 11
21 letter_img UCI, target: Z 26:1 20,000 16
22 yeast_me2 UCI, target: ME2 28:1 1,484 8
23 webpage LIBSVM, w7a, target: minority 33:1 34,780 300
24 ozone_level UCI, ozone, data 34:1 2,536 72
25 mammography UCI, target: minority 42:1 11,183 6
26 protein_homo KDD CUP 2004, minority 111:1 145,751 74
27 abalone_19 UCI, target: 19 130:1 4,177 10

ベースラインモデル

ベースラインモデルにはLightGBMを使用する。データセットを読み込んだ時点で最低限の前処理が施されているため、前処理を行わずともLightGBMを訓練することが可能である。評価指標には正解率、F1値、適合率、再現率を用いて、層化10分割交差検証を行う。これをsklearn.model_selection.cross_validata()で実装する。実行時にUndefinedMetricWarningが出る場合は、おそらく陽性を検出できていないことを示している。つまり、すべての予測が陰性となっていることを示している。そのため、適合率、再現率、F1値は0になっている。

from sklearn.model_selection import StratifiedKFold, cross_validate

dataset_scores = {}
scoring = ["accuracy", "f1", "precision", "recall"]
for dataset_name in datasets:
    kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)
    scores = cross_validate(
        lgb.LGBMClassifier(),
        datasets[dataset_name].data,
        datasets[dataset_name].target,
        scoring=scoring,
        cv=kfold
    )
    dataset_scores[dataset_name] = scores

他の評価指標を検討したい場合は、まずはsklearn.metrics.SCORERSを参照するとよい。

sklearn.metrics.SCORERS
{'accuracy': make_scorer(accuracy_score),
 'adjusted_mutual_info_score': make_scorer(adjusted_mutual_info_score),
 'adjusted_rand_score': make_scorer(adjusted_rand_score),
 'average_precision': make_scorer(average_precision_score, needs_threshold=True),
 'balanced_accuracy': make_scorer(balanced_accuracy_score),
 'completeness_score': make_scorer(completeness_score),
 'explained_variance': make_scorer(explained_variance_score),
 'f1': make_scorer(f1_score, average=binary),
 'f1_macro': make_scorer(f1_score, pos_label=None, average=macro),
 'f1_micro': make_scorer(f1_score, pos_label=None, average=micro),
 'f1_samples': make_scorer(f1_score, pos_label=None, average=samples),
 'f1_weighted': make_scorer(f1_score, pos_label=None, average=weighted),
 'fowlkes_mallows_score': make_scorer(fowlkes_mallows_score),
 'homogeneity_score': make_scorer(homogeneity_score),
 'jaccard': make_scorer(jaccard_score, average=binary),
 'jaccard_macro': make_scorer(jaccard_score, pos_label=None, average=macro),
 'jaccard_micro': make_scorer(jaccard_score, pos_label=None, average=micro),
 'jaccard_samples': make_scorer(jaccard_score, pos_label=None, average=samples),
 'jaccard_weighted': make_scorer(jaccard_score, pos_label=None, average=weighted),
 'max_error': make_scorer(max_error, greater_is_better=False),
 'mutual_info_score': make_scorer(mutual_info_score),
 'neg_brier_score': make_scorer(brier_score_loss, greater_is_better=False, needs_proba=True),
 'neg_log_loss': make_scorer(log_loss, greater_is_better=False, needs_proba=True),
 'neg_mean_absolute_error': make_scorer(mean_absolute_error, greater_is_better=False),
 'neg_mean_absolute_percentage_error': make_scorer(mean_absolute_percentage_error, greater_is_better=False),
 'neg_mean_gamma_deviance': make_scorer(mean_gamma_deviance, greater_is_better=False),
 'neg_mean_poisson_deviance': make_scorer(mean_poisson_deviance, greater_is_better=False),
 'neg_mean_squared_error': make_scorer(mean_squared_error, greater_is_better=False),
 'neg_mean_squared_log_error': make_scorer(mean_squared_log_error, greater_is_better=False),
 'neg_median_absolute_error': make_scorer(median_absolute_error, greater_is_better=False),
 'neg_root_mean_squared_error': make_scorer(mean_squared_error, greater_is_better=False, squared=False),
 'normalized_mutual_info_score': make_scorer(normalized_mutual_info_score),
 'precision': make_scorer(precision_score, average=binary),
 'precision_macro': make_scorer(precision_score, pos_label=None, average=macro),
 'precision_micro': make_scorer(precision_score, pos_label=None, average=micro),
 'precision_samples': make_scorer(precision_score, pos_label=None, average=samples),
 'precision_weighted': make_scorer(precision_score, pos_label=None, average=weighted),
 'r2': make_scorer(r2_score),
 'rand_score': make_scorer(rand_score),
 'recall': make_scorer(recall_score, average=binary),
 'recall_macro': make_scorer(recall_score, pos_label=None, average=macro),
 'recall_micro': make_scorer(recall_score, pos_label=None, average=micro),
 'recall_samples': make_scorer(recall_score, pos_label=None, average=samples),
 'recall_weighted': make_scorer(recall_score, pos_label=None, average=weighted),
 'roc_auc': make_scorer(roc_auc_score, needs_threshold=True),
 'roc_auc_ovo': make_scorer(roc_auc_score, needs_proba=True, multi_class=ovo),
 'roc_auc_ovo_weighted': make_scorer(roc_auc_score, needs_proba=True, multi_class=ovo, average=weighted),
 'roc_auc_ovr': make_scorer(roc_auc_score, needs_proba=True, multi_class=ovr),
 'roc_auc_ovr_weighted': make_scorer(roc_auc_score, needs_proba=True, multi_class=ovr, average=weighted),
 'top_k_accuracy': make_scorer(top_k_accuracy_score, needs_threshold=True),
 'v_measure_score': make_scorer(v_measure_score)}

今回は、この中のaccuracy, f1, precision, recallを選択した。

dataset_scoresには評価結果だけでなくモデルの訓練時間も含まれる。例えば、ecoliデータセットの交差検証の結果を確認する。

dataset_scores["ecoli"]
{'fit_time': array([0.10717535, 0.0581913 , 0.04513001, 0.07077527, 0.03441525,
        0.10235572, 0.04713297, 0.09230876, 0.05307102, 0.03584218]),
 'score_time': array([0.00427794, 0.0035553 , 0.01027751, 0.0081737 , 0.00449729,
        0.00350666, 0.01042509, 0.00973129, 0.0230062 , 0.02368474]),
 'test_accuracy': array([1.        , 0.88235294, 0.94117647, 0.94117647, 0.88235294,
        0.97058824, 0.90909091, 0.93939394, 0.90909091, 0.93939394]),
 'test_f1': array([1.        , 0.5       , 0.66666667, 0.66666667, 0.5       ,
        0.85714286, 0.57142857, 0.75      , 0.        , 0.66666667]),
 'test_precision': array([1.        , 0.5       , 1.        , 1.        , 0.5       ,
        1.        , 0.5       , 0.6       , 0.        , 0.66666667]),
 'test_recall': array([1.        , 0.5       , 0.5       , 0.5       , 0.5       ,
        0.75      , 0.66666667, 1.        , 0.        , 0.66666667])}

各フォールドの訓練時間は"fit_time"でアクセスできる。

print(f'各フォールドの訓練時間\n{dataset_scores["ecoli"]["fit_time"]}')
print(f'モデル訓練時間の平均\n{dataset_scores["ecoli"]["fit_time"].mean():.4f}')
各フォールドの訓練時間
[0.10717535 0.0581913  0.04513001 0.07077527 0.03441525 0.10235572
 0.04713297 0.09230876 0.05307102 0.03584218]
モデル訓練時間の平均
0.0646秒

dataset_scoresをpandasのデータフレームに変換して見やすくする。

scores_dict = {}
for dataset_name in datasets:
    scores_dict[dataset_name] = \
        ["{:.4f} ± {:.4f}".format(
            dataset_scores[dataset_name]["test_"+score_name].mean(),
            dataset_scores[dataset_name]["test_"+score_name].std()
        ) for score_name in scoring]

pd.DataFrame(scores_dict, index=scoring).T
accuracy f1 precision recall
ecoli 0.9315 ± 0.0352 0.6179 ± 0.2529 0.6767 ± 0.3124 0.6083 ± 0.2740
optical_digits 0.9909 ± 0.0045 0.9516 ± 0.0255 0.9903 ± 0.0098 0.9169 ± 0.0440
satimage 0.9500 ± 0.0062 0.7118 ± 0.0453 0.8055 ± 0.0324 0.6402 ± 0.0636
pen_digits 0.9978 ± 0.0017 0.9885 ± 0.0090 0.9962 ± 0.0063 0.9811 ± 0.0184
abalone 0.8911 ± 0.0100 0.1980 ± 0.0558 0.3272 ± 0.0860 0.1460 ± 0.0490
sick_euthyroid 0.9788 ± 0.0072 0.8823 ± 0.0428 0.9012 ± 0.0280 0.8674 ± 0.0725
spectrometer 0.9718 ± 0.0211 0.8150 ± 0.1340 0.9000 ± 0.1342 0.7800 ± 0.2040
car_eval_34 0.9936 ± 0.0055 0.9597 ± 0.0341 0.9531 ± 0.0575 0.9703 ± 0.0484
isolet 0.9859 ± 0.0032 0.9040 ± 0.0232 0.9460 ± 0.0213 0.8667 ± 0.0394
us_crime 0.9408 ± 0.0092 0.4881 ± 0.1006 0.7169 ± 0.1391 0.3867 ± 0.1147
yeast_ml8 0.9259 ± 0.0021 0.0000 ± 0.0000 0.0000 ± 0.0000 0.0000 ± 0.0000
scene 0.9273 ± 0.0053 0.1051 ± 0.0982 0.4067 ± 0.3969 0.0624 ± 0.0592
libras_move 0.9750 ± 0.0194 0.7600 ± 0.1837 0.9667 ± 0.1000 0.6833 ± 0.2734
thyroid_sick 0.9899 ± 0.0050 0.9154 ± 0.0427 0.9411 ± 0.0402 0.8918 ± 0.0522
coil_2000 0.9347 ± 0.0026 0.0636 ± 0.0398 0.2162 ± 0.1317 0.0375 ± 0.0237
arrhythmia 0.9691 ± 0.0226 0.6800 ± 0.2914 0.7000 ± 0.3055 0.6667 ± 0.2887
solar_flare_m0 0.9417 ± 0.0134 0.1175 ± 0.1188 0.2119 ± 0.2222 0.0881 ± 0.0965
oil 0.9626 ± 0.0161 0.3960 ± 0.2842 0.5833 ± 0.4167 0.3100 ± 0.2289
car_eval_4 1.0000 ± 0.0000 1.0000 ± 0.0000 1.0000 ± 0.0000 1.0000 ± 0.0000
wine_quality 0.9657 ± 0.0034 0.3210 ± 0.0588 0.6472 ± 0.1637 0.2181 ± 0.0467
letter_img 0.9981 ± 0.0010 0.9737 ± 0.0134 0.9876 ± 0.0152 0.9605 ± 0.0222
yeast_me2 0.9623 ± 0.0091 0.2528 ± 0.1844 0.4250 ± 0.3622 0.2000 ± 0.1549
webpage 0.9884 ± 0.0015 0.7671 ± 0.0330 0.8867 ± 0.0294 0.6769 ± 0.0415
ozone_level 0.9708 ± 0.0061 0.1728 ± 0.1936 0.3500 ± 0.3471 0.1268 ± 0.1622
mammography 0.9880 ± 0.0024 0.6936 ± 0.0651 0.8527 ± 0.0698 0.5885 ± 0.0789
protein_homo 0.9976 ± 0.0005 0.8509 ± 0.0310 0.9274 ± 0.0254 0.7870 ± 0.0427
abalone_19 0.9919 ± 0.0012 0.0000 ± 0.0000 0.0000 ± 0.0000 0.0000 ± 0.0000

データセットの中には、陽性を全く検出できていないものがある。今回はyeast_ml8, abalone_19の2つのデータセットが、陽性を全く検出できていない。yeast_ml8, abalone_19はF1値、適合率、再現率が0である。

一方、陽性と陰性の分類が完璧なデータセットがある。car_eval_4は陽性と陰性を全て正しく分類している。正解率、F1値、適合率、再現率のすべてが1.0である。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2