不均衡データセットのベースラインモデル評価
Imbalanced-learnに実装されている27の不均衡なデータセットを使用する。ここでは、以下のことを行う。
- 27のデータセットの読み込み
- LightGBMによるベースラインモデルの評価
実行環境
このnotebookはGoogle Colaboratoryで実行されている。
!python --version
Python 3.7.13
ライブラリ
使用しているライブラリは以下の4つ。
import imblearn
import sklearn
import pandas as pd
import lightgbm as lgb
print(f"Imbalanced-learn version {imblearn.__version__}")
print(f"Scikit-learn version {sklearn.__version__}")
print(f"Pandas version {pd.__version__}")
print(f"LightGBM version {lgb.__version__}")
Imbalanced-learn version 0.8.1
Scikit-learn version 1.0.2
Pandas version 1.3.5
LightGBM version 2.2.3
データセットの読み込み
Imbalanced-learnのベンチマークデータセットを使用する。27のデータセットはimblearn.datasets.fetch_datasets()
で読み込む。データセットはImbalanced dataset for benchmarking: zenodoからダウンロードされる。
from imblearn.datasets import fetch_datasets
datasets = fetch_datasets()
datasets
OrderedDict([('ecoli',
{'DESCR': 'ecoli',
'data': array([[0.49, 0.29, 0.48, ..., 0.56, 0.24, 0.35],
[0.07, 0.4 , 0.48, ..., 0.54, 0.35, 0.44],
[0.56, 0.4 , 0.48, ..., 0.49, 0.37, 0.46],
...,
[0.61, 0.6 , 0.48, ..., 0.44, 0.39, 0.38],
[0.59, 0.61, 0.48, ..., 0.42, 0.42, 0.37],
[0.74, 0.74, 0.48, ..., 0.31, 0.53, 0.52]]),
'target': array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])}),
....
('protein_homo',
{'DESCR': 'protein_homo',
'data': array([[ 52. , 32.69, 0.3 , ..., -0.35, 0.26, 0.76],
[ 58. , 33.33, 0. , ..., 1.16, 0.39, 0.73],
[ 77. , 27.27, -0.91, ..., -0.76, 0.26, 0.24],
...,
[100. , 71.76, 41.92, ..., 3.41, 0.44, 0.78],
[ 85.65, 26.46, 1.85, ..., 2.88, 0.54, 0.77],
[ 87.5 , 29.33, 5.84, ..., -0.58, 0.16, 0.23]]),
'target': array([-1, -1, -1, ..., 1, -1, 1])}),
('abalone_19',
{'DESCR': 'abalone_19',
'data': array([[0. , 0. , 1. , ..., 0.2245, 0.101 , 0.15 ],
[0. , 0. , 1. , ..., 0.0995, 0.0485, 0.07 ],
[1. , 0. , 0. , ..., 0.2565, 0.1415, 0.21 ],
...,
[0. , 0. , 1. , ..., 0.5255, 0.2875, 0.308 ],
[1. , 0. , 0. , ..., 0.531 , 0.261 , 0.296 ],
[0. , 0. , 1. , ..., 0.9455, 0.3765, 0.495 ]]),
'target': array([-1, -1, -1, ..., -1, -1, -1])})])
データセットの情報は下記の表にまとめてある。Nameはデータセットの名前であり、個々のデータセットにアクセスするときに使用する。Repositoryはデータセットの配布場所である。中には多クラス分類問題のデータセットが含まれているが、二クラス分類問題となるように加工されている。Targetは少数クラスのカテゴリを指している。Ratioは多数クラスと少数クラス間のサンプル数の比である。これは、不均衡の度合いを示す。#Sはサンプル数であり、#Fは特徴数である。カテゴリ変数がone-hot表現に変換されていることに注意する。
ID | Name | Repository & Target | Ratio | #S | #F |
---|---|---|---|---|---|
1 | ecoli | UCI, target: imU | 8.6:1 | 336 | 7 |
2 | optical_digits | UCI, target: 8 | 9.1:1 | 5,620 | 64 |
3 | satimage | UCI, target: 4 | 9.3:1 | 6,435 | 36 |
4 | pen_digits | UCI, target: 5 | 9.4:1 | 10,992 | 16 |
5 | abalone | UCI, target: 7 | 9.7:1 | 4,177 | 10 |
6 | sick_euthyroid | UCI, target: sick euthyroid | 9.8:1 | 3,163 | 42 |
7 | spectrometer | UCI, target: >=44 | 11:1 | 531 | 93 |
8 | car_eval_34 | UCI, target: good, v good | 12:1 | 1,728 | 21 |
9 | isolet | UCI, target: A, B | 12:1 | 7,797 | 617 |
10 | us_crime | UCI, target: >0.65 | 12:1 | 1,994 | 100 |
11 | yeast_ml8 | LIBSVM, target: 8 | 13:1 | 2,417 | 103 |
12 | scene | LIBSVM, target: >one label | 13:1 | 2,407 | 294 |
13 | libras_move | UCI, target: 1 | 14:1 | 360 | 90 |
14 | thyroid_sick | UCI, target: sick | 15:1 | 3,772 | 52 |
15 | coil_2000 | KDD, CoIL, target: minority | 16:1 | 9,822 | 85 |
16 | arrhythmia | UCI, target: 06 | 17:1 | 452 | 278 |
17 | solar_flare_m0 | UCI, target: M->0 | 19:1 | 1,389 | 32 |
18 | oil | UCI, target: minority | 22:1 | 937 | 49 |
19 | car_eval_4 | UCI, target: vgood | 26:1 | 1,728 | 21 |
20 | wine_quality | UCI, wine, target: <=4 | 26:1 | 4,898 | 11 |
21 | letter_img | UCI, target: Z | 26:1 | 20,000 | 16 |
22 | yeast_me2 | UCI, target: ME2 | 28:1 | 1,484 | 8 |
23 | webpage | LIBSVM, w7a, target: minority | 33:1 | 34,780 | 300 |
24 | ozone_level | UCI, ozone, data | 34:1 | 2,536 | 72 |
25 | mammography | UCI, target: minority | 42:1 | 11,183 | 6 |
26 | protein_homo | KDD CUP 2004, minority | 111:1 | 145,751 | 74 |
27 | abalone_19 | UCI, target: 19 | 130:1 | 4,177 | 10 |
ベースラインモデル
ベースラインモデルにはLightGBMを使用する。データセットを読み込んだ時点で最低限の前処理が施されているため、前処理を行わずともLightGBMを訓練することが可能である。評価指標には正解率、F1値、適合率、再現率を用いて、層化10分割交差検証を行う。これをsklearn.model_selection.cross_validata()
で実装する。実行時にUndefinedMetricWarning
が出る場合は、おそらく陽性を検出できていないことを示している。つまり、すべての予測が陰性となっていることを示している。そのため、適合率、再現率、F1値は0になっている。
from sklearn.model_selection import StratifiedKFold, cross_validate
dataset_scores = {}
scoring = ["accuracy", "f1", "precision", "recall"]
for dataset_name in datasets:
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)
scores = cross_validate(
lgb.LGBMClassifier(),
datasets[dataset_name].data,
datasets[dataset_name].target,
scoring=scoring,
cv=kfold
)
dataset_scores[dataset_name] = scores
他の評価指標を検討したい場合は、まずはsklearn.metrics.SCORERS
を参照するとよい。
sklearn.metrics.SCORERS
{'accuracy': make_scorer(accuracy_score),
'adjusted_mutual_info_score': make_scorer(adjusted_mutual_info_score),
'adjusted_rand_score': make_scorer(adjusted_rand_score),
'average_precision': make_scorer(average_precision_score, needs_threshold=True),
'balanced_accuracy': make_scorer(balanced_accuracy_score),
'completeness_score': make_scorer(completeness_score),
'explained_variance': make_scorer(explained_variance_score),
'f1': make_scorer(f1_score, average=binary),
'f1_macro': make_scorer(f1_score, pos_label=None, average=macro),
'f1_micro': make_scorer(f1_score, pos_label=None, average=micro),
'f1_samples': make_scorer(f1_score, pos_label=None, average=samples),
'f1_weighted': make_scorer(f1_score, pos_label=None, average=weighted),
'fowlkes_mallows_score': make_scorer(fowlkes_mallows_score),
'homogeneity_score': make_scorer(homogeneity_score),
'jaccard': make_scorer(jaccard_score, average=binary),
'jaccard_macro': make_scorer(jaccard_score, pos_label=None, average=macro),
'jaccard_micro': make_scorer(jaccard_score, pos_label=None, average=micro),
'jaccard_samples': make_scorer(jaccard_score, pos_label=None, average=samples),
'jaccard_weighted': make_scorer(jaccard_score, pos_label=None, average=weighted),
'max_error': make_scorer(max_error, greater_is_better=False),
'mutual_info_score': make_scorer(mutual_info_score),
'neg_brier_score': make_scorer(brier_score_loss, greater_is_better=False, needs_proba=True),
'neg_log_loss': make_scorer(log_loss, greater_is_better=False, needs_proba=True),
'neg_mean_absolute_error': make_scorer(mean_absolute_error, greater_is_better=False),
'neg_mean_absolute_percentage_error': make_scorer(mean_absolute_percentage_error, greater_is_better=False),
'neg_mean_gamma_deviance': make_scorer(mean_gamma_deviance, greater_is_better=False),
'neg_mean_poisson_deviance': make_scorer(mean_poisson_deviance, greater_is_better=False),
'neg_mean_squared_error': make_scorer(mean_squared_error, greater_is_better=False),
'neg_mean_squared_log_error': make_scorer(mean_squared_log_error, greater_is_better=False),
'neg_median_absolute_error': make_scorer(median_absolute_error, greater_is_better=False),
'neg_root_mean_squared_error': make_scorer(mean_squared_error, greater_is_better=False, squared=False),
'normalized_mutual_info_score': make_scorer(normalized_mutual_info_score),
'precision': make_scorer(precision_score, average=binary),
'precision_macro': make_scorer(precision_score, pos_label=None, average=macro),
'precision_micro': make_scorer(precision_score, pos_label=None, average=micro),
'precision_samples': make_scorer(precision_score, pos_label=None, average=samples),
'precision_weighted': make_scorer(precision_score, pos_label=None, average=weighted),
'r2': make_scorer(r2_score),
'rand_score': make_scorer(rand_score),
'recall': make_scorer(recall_score, average=binary),
'recall_macro': make_scorer(recall_score, pos_label=None, average=macro),
'recall_micro': make_scorer(recall_score, pos_label=None, average=micro),
'recall_samples': make_scorer(recall_score, pos_label=None, average=samples),
'recall_weighted': make_scorer(recall_score, pos_label=None, average=weighted),
'roc_auc': make_scorer(roc_auc_score, needs_threshold=True),
'roc_auc_ovo': make_scorer(roc_auc_score, needs_proba=True, multi_class=ovo),
'roc_auc_ovo_weighted': make_scorer(roc_auc_score, needs_proba=True, multi_class=ovo, average=weighted),
'roc_auc_ovr': make_scorer(roc_auc_score, needs_proba=True, multi_class=ovr),
'roc_auc_ovr_weighted': make_scorer(roc_auc_score, needs_proba=True, multi_class=ovr, average=weighted),
'top_k_accuracy': make_scorer(top_k_accuracy_score, needs_threshold=True),
'v_measure_score': make_scorer(v_measure_score)}
今回は、この中のaccuracy
, f1
, precision
, recall
を選択した。
dataset_scores
には評価結果だけでなくモデルの訓練時間も含まれる。例えば、ecoliデータセットの交差検証の結果を確認する。
dataset_scores["ecoli"]
{'fit_time': array([0.10717535, 0.0581913 , 0.04513001, 0.07077527, 0.03441525,
0.10235572, 0.04713297, 0.09230876, 0.05307102, 0.03584218]),
'score_time': array([0.00427794, 0.0035553 , 0.01027751, 0.0081737 , 0.00449729,
0.00350666, 0.01042509, 0.00973129, 0.0230062 , 0.02368474]),
'test_accuracy': array([1. , 0.88235294, 0.94117647, 0.94117647, 0.88235294,
0.97058824, 0.90909091, 0.93939394, 0.90909091, 0.93939394]),
'test_f1': array([1. , 0.5 , 0.66666667, 0.66666667, 0.5 ,
0.85714286, 0.57142857, 0.75 , 0. , 0.66666667]),
'test_precision': array([1. , 0.5 , 1. , 1. , 0.5 ,
1. , 0.5 , 0.6 , 0. , 0.66666667]),
'test_recall': array([1. , 0.5 , 0.5 , 0.5 , 0.5 ,
0.75 , 0.66666667, 1. , 0. , 0.66666667])}
各フォールドの訓練時間は"fit_time"
でアクセスできる。
print(f'各フォールドの訓練時間\n{dataset_scores["ecoli"]["fit_time"]}')
print(f'モデル訓練時間の平均\n{dataset_scores["ecoli"]["fit_time"].mean():.4f}秒')
各フォールドの訓練時間
[0.10717535 0.0581913 0.04513001 0.07077527 0.03441525 0.10235572
0.04713297 0.09230876 0.05307102 0.03584218]
モデル訓練時間の平均
0.0646秒
dataset_scores
をpandasのデータフレームに変換して見やすくする。
scores_dict = {}
for dataset_name in datasets:
scores_dict[dataset_name] = \
["{:.4f} ± {:.4f}".format(
dataset_scores[dataset_name]["test_"+score_name].mean(),
dataset_scores[dataset_name]["test_"+score_name].std()
) for score_name in scoring]
pd.DataFrame(scores_dict, index=scoring).T
accuracy | f1 | precision | recall | |
---|---|---|---|---|
ecoli | 0.9315 ± 0.0352 | 0.6179 ± 0.2529 | 0.6767 ± 0.3124 | 0.6083 ± 0.2740 |
optical_digits | 0.9909 ± 0.0045 | 0.9516 ± 0.0255 | 0.9903 ± 0.0098 | 0.9169 ± 0.0440 |
satimage | 0.9500 ± 0.0062 | 0.7118 ± 0.0453 | 0.8055 ± 0.0324 | 0.6402 ± 0.0636 |
pen_digits | 0.9978 ± 0.0017 | 0.9885 ± 0.0090 | 0.9962 ± 0.0063 | 0.9811 ± 0.0184 |
abalone | 0.8911 ± 0.0100 | 0.1980 ± 0.0558 | 0.3272 ± 0.0860 | 0.1460 ± 0.0490 |
sick_euthyroid | 0.9788 ± 0.0072 | 0.8823 ± 0.0428 | 0.9012 ± 0.0280 | 0.8674 ± 0.0725 |
spectrometer | 0.9718 ± 0.0211 | 0.8150 ± 0.1340 | 0.9000 ± 0.1342 | 0.7800 ± 0.2040 |
car_eval_34 | 0.9936 ± 0.0055 | 0.9597 ± 0.0341 | 0.9531 ± 0.0575 | 0.9703 ± 0.0484 |
isolet | 0.9859 ± 0.0032 | 0.9040 ± 0.0232 | 0.9460 ± 0.0213 | 0.8667 ± 0.0394 |
us_crime | 0.9408 ± 0.0092 | 0.4881 ± 0.1006 | 0.7169 ± 0.1391 | 0.3867 ± 0.1147 |
yeast_ml8 | 0.9259 ± 0.0021 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
scene | 0.9273 ± 0.0053 | 0.1051 ± 0.0982 | 0.4067 ± 0.3969 | 0.0624 ± 0.0592 |
libras_move | 0.9750 ± 0.0194 | 0.7600 ± 0.1837 | 0.9667 ± 0.1000 | 0.6833 ± 0.2734 |
thyroid_sick | 0.9899 ± 0.0050 | 0.9154 ± 0.0427 | 0.9411 ± 0.0402 | 0.8918 ± 0.0522 |
coil_2000 | 0.9347 ± 0.0026 | 0.0636 ± 0.0398 | 0.2162 ± 0.1317 | 0.0375 ± 0.0237 |
arrhythmia | 0.9691 ± 0.0226 | 0.6800 ± 0.2914 | 0.7000 ± 0.3055 | 0.6667 ± 0.2887 |
solar_flare_m0 | 0.9417 ± 0.0134 | 0.1175 ± 0.1188 | 0.2119 ± 0.2222 | 0.0881 ± 0.0965 |
oil | 0.9626 ± 0.0161 | 0.3960 ± 0.2842 | 0.5833 ± 0.4167 | 0.3100 ± 0.2289 |
car_eval_4 | 1.0000 ± 0.0000 | 1.0000 ± 0.0000 | 1.0000 ± 0.0000 | 1.0000 ± 0.0000 |
wine_quality | 0.9657 ± 0.0034 | 0.3210 ± 0.0588 | 0.6472 ± 0.1637 | 0.2181 ± 0.0467 |
letter_img | 0.9981 ± 0.0010 | 0.9737 ± 0.0134 | 0.9876 ± 0.0152 | 0.9605 ± 0.0222 |
yeast_me2 | 0.9623 ± 0.0091 | 0.2528 ± 0.1844 | 0.4250 ± 0.3622 | 0.2000 ± 0.1549 |
webpage | 0.9884 ± 0.0015 | 0.7671 ± 0.0330 | 0.8867 ± 0.0294 | 0.6769 ± 0.0415 |
ozone_level | 0.9708 ± 0.0061 | 0.1728 ± 0.1936 | 0.3500 ± 0.3471 | 0.1268 ± 0.1622 |
mammography | 0.9880 ± 0.0024 | 0.6936 ± 0.0651 | 0.8527 ± 0.0698 | 0.5885 ± 0.0789 |
protein_homo | 0.9976 ± 0.0005 | 0.8509 ± 0.0310 | 0.9274 ± 0.0254 | 0.7870 ± 0.0427 |
abalone_19 | 0.9919 ± 0.0012 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
データセットの中には、陽性を全く検出できていないものがある。今回はyeast_ml8, abalone_19の2つのデータセットが、陽性を全く検出できていない。yeast_ml8, abalone_19はF1値、適合率、再現率が0である。
一方、陽性と陰性の分類が完璧なデータセットがある。car_eval_4は陽性と陰性を全て正しく分類している。正解率、F1値、適合率、再現率のすべてが1.0である。