14
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

機械学習ツールを掘り下げる by 日経 xTECH ビジネスAI③Advent Calendar 2019

Day 11

Light GBMの多クラス分類/二値分類を可視化するConfusion Matrix Plot

Last updated at Posted at 2019-12-11

Introduction

Summary

・scikit-learn 0.22のアップデートであるconfusion matrixのplotを試してみた
・従来のscikit-learnのconfusion matrixはarrayで出力していたのでグラフ映えしなかった
・LightGBMを用いる場合はファンクションの引数の分類モデルを
 ___scikit-learnインターフェース___で学習させないとできなかった
 scikit-learnの呪い?
・plotするメリットとしては、ファンクション内で__y_predを勝手に計算してくれる__ので
 ちょっと楽なのと見やすくて映える

Dataset

みんな大好きiris。目的変数はもちろんspecies。
image.png

#Body

準備

まずはアップデートを忘れずに。

ライブラリのアップデート
!pip install --upgrade scikit-learn

下記の___plot_confusion_matrix__が今回のテーマのメイン。

今回必要なライブラリのimport
#default
import numpy as np
import pandas as pd
import seaborn as sns

#for_modeling
import lightgbm as lgb

#for_plot
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, plot_confusion_matrix
from sklearn.model_selection import train_test_split

irisデータセットはseabornライブラリからimportします。

データセットの整理
df = sns.load_dataset('iris')

X=df.drop(["species"],axis=1)
y=df["species"]

SEED=5
X_train_tmp, X_test,y_train_tmp,y_test=train_test_split(X, y, test_size=0.2, random_state=SEED)
X_train, X_val, y_train, y_val=train_test_split(X_train_tmp, y_train_tmp, test_size=0.2, random_state=SEED)

##オリジナルインターフェースで学習させたパターン
Summaryで書いた通り失敗するのだが、
まずはLightGBMモデルをオリジナルのインターフェースで記述したケースで実行してみる。

オリジナルインターフェースでの学習
#失敗例
lgb_params = {
    'objective':'multiclass',
    'n_estimators':1000,
    'seed': SEED,
    'early_stopping_rounds':100
} 

lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_val, y_val)

model = lgb.train(
            lgb_params,
            lgb_train,
            valid_sets=lgb_eval,
            verbose_eval = 15
            )

confusion matrixを出力してみる。
引数は下記の通りで、estimatorに学習させたモデル、Xに検証部分の特徴量(X_test)、Yに正解データ(Y_test)を入れることで出力される。
y_predは自動的に計算されるので宣言が必要ない。

sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, labels=None, sample_weight=None, normalize=None, display_labels=None, include_values=True, xticks_rotation='horizontal', values_format=None, cmap='viridis', ax=None)

引用元:scikit-learn公式ドキュメント

プロット出力
plot_confusion_matrix(model,X_test,y_test)

すると下記のエラーが返ってくる。
estimatorに指定しているモデルが分類モデルと認識されていない様。

エラー文
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-125-5b518d7bfee9> in <module>()
----> 1 plot_confusion_matrix(model,X_test,y_test)

/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_plot/confusion_matrix.py in plot_confusion_matrix(estimator, X, y_true, labels, sample_weight, normalize, display_labels, include_values, xticks_rotation, values_format, cmap, ax)
    183 
    184     if not is_classifier(estimator):
--> 185         raise ValueError("plot_confusion_matrix only supports classifiers")
    186 
    187     if normalize not in {'true', 'pred', 'all', None}:

ValueError: plot_confusion_matrix only supports classifiers

たしかLightGBMをscikit-learnインターフェースで記述する場合であれば
__LGBMClassifier__といういかにも分類モデルらしいLightGBMの記述方法があったことを思い出す。
試しにその記述方法で試してみる。

scikit-learnインターフェースで学習させたパターン

scikit-learnインターフェースでの学習
model2 = lgb.LGBMClassifier(objective='multiclass',
                        n_estimators=1000,
                        seed=SEED,
                        early_stopping_rounds=100)
model2.fit(X_train, y_train,
        eval_set=[(X_val, y_val)],
        verbose=50)

model2で再トライしてみる。

プロット再出力
plot_confusion_matrix(model2,X_test,y_test)

出た。
seabornでのplotのようにデフォルトで見やすいプロットが出力された。

image.png

タイトルをつける場合は下記のように記載すれば出力される。

タイトルを追加
disp=plot_confusion_matrix(model2,X_test,y_test)
disp.ax_.set_title("Confusion Matrix")
plt.show()

image.png

ちなみにコードは割愛するがtitanicの二値分類でも全く同様にして出力できた。
逆に従来の__confusion_matrix.confusion_matrix__の場合は二値分類しかできず、多クラス分類だとできなかった。

image.png

Conclusion

・scikit-learnの新機能のconfusion matrix plotを用いることでかなりシンプルな記載で描画することができた。
・LightGBMの場合は現状、scikit-learnインターフェースでの学習させた方が良い。
・従来の__confusion_matrix.confusion_matrix__とは異なり、二値分類でも多クラス分類でも対応している。

14
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?