0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

サポートベクトルマシン

Last updated at Posted at 2021-12-22

■サポートベクトルマシン
【特徴】
・線形2値分類手法
・外れ値の影響を受けにくい

どのように2値分類を行うのか?
・マージン:判別境界の周りにデータポイントが存在しない帯状の領域
・サポートベクトル:マージンの境界上のデータポイント
・マージンの大きさ:判別境界からサポートベクトルまでの距離(点と直線の距離)
上記の3点よりマージン=距離の大きさが最大となるように判別境界を設定します。

【点と直線の距離の公式】
点$(x_1, x_2)から直線w_1x_1 + w_2x_2 + b = 0$に下した垂線の長さを、点と直線間の距離といい下記の公式により求まります。

$距離 =\frac{|w_1x_1 + w_2x_2 + b|}{\sqrt{w_1^2 + w_2^2}}$

【イメージ】

スクリーンショット 2021-12-20 0.16.38.png

実装

######1.モジュールインポート

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

######2.データ読み込み&分割

import pandas as pd

df_wine = pd.read_csv('https://archive.ics.uci.edu/ml/'
                      'machine-learning-databases/wine/wine.data',
                      header=None)

df_wine.columns = ['Class label', 'Alcohol', 'Malic acid', 'Ash',
                   'Alcalinity of ash', 'Magnesium', 'Total phenols',
                   'Flavanoids', 'Nonflavanoid phenols', 'Proanthocyanins',
                   'Color intensity', 'Hue', 'OD280/OD315 of diluted wines',
                   'Proline'] #13個



# classlabel 1を削除 
df_wine = df_wine[df_wine['Class label'] != 1]


y = df_wine['Class label'].values
X = df_wine[['Alcohol', 'OD280/OD315 of diluted wines']].values

######3.ラベル変更

df_wine_data = df_wine[['Alcohol', 'OD280/OD315 of diluted wines']]

df_wine_data['Class'] = df_wine['Class label'].values

df_wine_target = df_wine[['Class label']]

#データラベルを1と-1に変更
df_wine_data['Class'] = df_wine_data['Class'].apply(lambda x: 1 if x == 2 else -1)

df_wine_data

スクリーンショット 2021-12-22 22.26.26.png

######4.データの可視化

X, y = df_wine_data[['Alcohol','OD280/OD315 of diluted wines']], df_wine_data['Class']

x1 = X['Alcohol']
x2 = X['OD280/OD315 of diluted wines']

plt.scatter(x1, x2, c=y, cmap='autumn')
plt.show()

スクリーンショット 2021-12-22 22.27.14.png

######5.モデル学習

from sklearn.svm import SVC
svc = SVC()
svc.fit(X, y)

#予測
y_pred = svc.predict(X_test)

#正解率確認
accuracy_score(y_test, y_pred)
#1.0

#決定領域プロット
xx, yy = np.meshgrid(np.linspace(start=11, stop=15, num=100),
                     np.linspace(start=1, stop=4, num=100))
x_test = np.c_[xx.ravel(), yy.ravel()]

y_pred = svc.predict(x_test)
y_pred

スクリーンショット 2021-12-22 22.28.16.png

以上、モデルの作成・予測を行うことができました。今回は線形分離可能なモデルでしたが次回は非線形分類についても記事を書けたらと思います。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?