6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

sklearnでも特徴量の重要度を見たい

Posted at

概要

lightgbmには特徴量の重要度を出すplot_importanceという関数がある。
Python: LightGBM を使ってみる

Python: LightGBM を使ってみる より

sklearnでも特徴量の重要度を可視化したい、という気持ちになるのでやります。

ソースコード

みんな大好きirisデータセットを使っていきます。

from pprint import pprint

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier

iris = load_iris()

irisの情報を見る

irisについてもう少し情報を得ましょう。
irisは3クラス、特徴量は4種類あることがわかります。

pprint(iris)
{'DESCR': '.. _iris_dataset:\n'
          '\n'
          'Iris plants dataset\n'
          '--------------------\n'
          '\n'
          '**Data Set Characteristics:**\n'
          '\n'
          '    :Number of Instances: 150 (50 in each of three classes)\n'
          '    :Number of Attributes: 4 numeric, predictive attributes and the '
          'class\n'
          '    :Attribute Information:\n'
          '        - sepal length in cm\n'
          '        - sepal width in cm\n'
          '        - petal length in cm\n'
          '        - petal width in cm\n'
          '        - class:\n'
          '                - Iris-Setosa\n'
          '                - Iris-Versicolour\n'
          '                - Iris-Virginica\n'
          '                \n'
          '    :Summary Statistics:\n'
          '\n'
          '    ============== ==== ==== ======= ===== ====================\n'
          '                    Min  Max   Mean    SD   Class Correlation\n'
          '    ============== ==== ==== ======= ===== ====================\n'
          '    sepal length:   4.3  7.9   5.84   0.83    0.7826\n'
          '    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n'
          '    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n'
          '    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)\n'
          '    ============== ==== ==== ======= ===== ====================\n'
          

irisのデータを学習する

特徴の名前やデータを取得しておきましょう。
なにも考えずにfitします。

feature_names = iris.feature_names
data = iris.data
target = iris.target
target_names = iris.target_names

clf = SGDClassifier()
clf.fit(data, target)

重みを確認する

重みはcoef_を見ることでわかります。
shapeは(n_classes, n_features)になっています。

print(clf.coef_)

array([[  9.76085896,  30.74670571, -48.80429478, -27.33040508],
       [  9.51335529, -73.66752043,  28.90596414, -68.17904623],
       [-89.68307485, -84.96291301, 151.2699483 , 135.31130591]])

特徴量の重要度を描画する

ここまできたら後は描画するだけです。
出力結果は以下の画像です。

petal lengthとpetal widthが重要な変数であることがわかります。
これは、先ほどみたirisの情報の中でClass Correlationが高いと書かれていることからもわかると思います。

df = pd.DataFrame(clf.coef_.T, columns=target_names, index=name)

df.plot.bar()

coef.png

まとめ

  • coef_を見ることで特徴量の重要度がわかる
  • pandasを使うと見やすく簡単に描画できる
6
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?