1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

サポートベクタマシン(load_iris)

Last updated at Posted at 2020-10-06

今回はサポートベクタマシン(分類)の実装をまとめていきます。

※以降、SVMと省略することがあります。

##■ サポートベクタマシンの手順

次の7つのSTEPで進めます。

  1. モジュールの用意
  2. データの準備
  3. データの可視化
  4. モデルの作成
  5. モデルのプロット
  6. 分類を予測
  7. モデルの評価

##1. モジュールの用意
最初に、必要なモジュールをインポートしておきます。


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# データセットを読み込むモジュール
from sklearn.datasets import load_iris

# 標準化(分散正規化)を行うモジュール
from sklearn.preprocessing import StandardScaler

# 訓練データとテストデータを分割するモジュール
from sklearn.model_selection import train_test_split

# サポートベクトルマシンを実行するモジュール
from sklearn.svm import SVC

# 分類の評価を行うモジュール
from sklearn.metrics import classification_report

##2. データの準備
今回はirisデータセットを使って、二値分類をしていきます。

最初にデータの取得をし、標準化を行ってから分割します。


# irisデータセットの読み込み
iris = load_iris()

# 目的変数と説明変数に分ける
X, y = iris.data[:100, [0, 2]], iris.target[:100]

# 標準化(分散正規化)
std = StandardScaler()
X = std.fit_transform(X)

# 訓練データとテストデータに分割する
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

二値分類を行うために、データセットを100行目まで(Setosa・Versicolor のみ)と指定しています。
またプロットしやすくするために、説明変数も2つに絞っています。(Sepal Length・Petal Lengh のみ)

標準化は、例えば2桁と4桁の特徴量(説明変数)があった際に、後者の影響が大きくなってしまうため
全ての特徴量に対して平均を0・分散を1にして、スケールを揃えています。


##3. データの可視化
SVMで分類をする前のデータをプロットして見ておきます。


# 描画オブジェクトとサブプロットの作成
fig, ax = plt.subplots()

# Setosa のプロット
ax.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], 
           marker = 'o', label = 'Setosa')

# Versicolor のプロット
ax.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1],
           marker = 'x', label = 'Versicolor')

# 軸ラベルの設定
ax.set_xlabel('Sepal Length')
ax.set_ylabel('Petal Length')

# 凡例の設定 
ax.legend(loc = 'best')

plt.show()

Setosa (y_train ==0) に対応する特徴量(0:Sepal Lengh を横軸, 1: Petal Length を縦軸)でプロット
Versicolor (y_train ==1) に対応する特徴量(0:Sepal Lengh を横軸, 1: Petal Length を縦軸)でプロット

<出力結果>
image.png
##4. モデルの作成
最初にSVMの実行関数(インスタンス)を作成し、訓練データに当てはめます。


# インスタンスを作成
svc = SVC(kernel = 'linear', C = 1e6)
    
# 訓練データからモデルを作成
svc.fit(X_train, y_train)

今回はすでに線形分離(1本の直線で分ける)が可能なので、引数を kernel = 'linear' に設定しています。

また C はハイパーパラメータで、出力数値やプロットを見ながら自身で調整していくものです。


##5. モデルのプロット
訓練データからサポートベクタマシンのモデルが作成できたので
どのように分類ができているのか、プロットして確認します。

前半は、先ほどの散布図のコードと全く同じです。


# 描画オブジェクトとサブプロットの作成
fig, ax = plt.subplots()

# Setosa のプロット
ax.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], 
           marker = 'o', label = 'Setosa')

# Versicolor のプロット
ax.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1],
           marker = 'x', label = 'Versicolor')

ax.set_xlabel('Sepal Length')
ax.set_ylabel('Petal Length')
    
ax.legend(loc = 'upper left')

# ここから下は、他データの場合でもそのままペーストして毎回流用できます。(数値の微調整は必要)

# 決定境界(直線)のプロット範囲を指定
xmin = -2.0
xmax = 2.5
ymin = -1.5
ymax = 1.8

# 決定境界とマージンをプロット
xx, yy = np.meshgrid(np.linspace(xmin, xmax, 100), np.linspace(ymin, ymax, 100))
xy = np.vstack([xx.ravel(), yy.ravel()]).T
p = svc.decision_function(xy).reshape(100, 100)  
ax.contour(xx, yy, p, colors = 'k', levels = [-1, 0, 1], alpha = 1, 
           linestyles = ['--', '-', '--'])

# サポートベクタをプロット
ax.scatter(svc.support_vectors_[:, 0], svc.support_vectors_[:, 1],
           s = 250, facecolors = 'none', edgecolors = 'black')
    
plt.show()

alpha:直線の濃さ
s:サポートベクタ(○)の大きさ

<出力結果>
image.png
##6. 分類を予測
モデルが完成したので、分類の予測をしていきます。

# 分類結果を予測する
y_pred = svc.predict(X_test)

# 予測値と正解値を出力
print(y_pred)
print(y_test)

<出力結果>


# 予測値と正解値を比較してみる
y_pred: [0 1 1 1 0 0 1 0 1 0 0 1 0 0 0 0 0 0 1 1 1 0 0 1 0 0 1 1 1 1]
y_test: [0 1 1 1 0 0 1 0 1 0 0 1 0 0 0 0 0 0 1 1 1 0 0 1 0 0 1 1 1 1]

0:Setosa 1:Versicolor

今回の場合は全て一致(正解)していることが分かります。


##7. モデルの評価
今回は分類(二値分類)となるので、混同行列を用いた適合率・再現率・F値で評価を行います。


# 適合率、再現率、F値を出力
print(classification_report(y_test, y_pred))

<出力結果>
image.png

以上より、Setosa と Versicolor における分類の評価を行うことができました。


##■ 最後に
SVMでは上記1~7の手順をもとに、モデルの作成・評価を行っていきます。

今回は初学者の方向けに、実装(コード)のみまとめさせていただきましたが
今後タイミングを見て、理論(数式)についても記事を作成していければと思います。

ご精読いただき、ありがとうございました。

参考文献:Pythonによるあたらしいデータ分析の教科書
    (Python 3 エンジニア認定データ分析試験 主教材)

1
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?