皆さんこんばんわ。
12/19の文を書いていきます。
今日は用事があって遅くなりましたね。。。
今日は昨日勉強した直線SVCモデルの分類をまとめていこうと思います。
分類の目的
上の写真を見てください。
分類とはこの青の点とオレンジのバツを分ける線を自動的に見つけることです。
2次元だと人間にとってとても簡単ですよね?
SVCモデルは高次元の場合でも分ける為の式を求めれられるらしいので凄いですよね。
さぁ、これが何に役に立つの?と思った方も多いでしょう。
下の画像はいかがでしょうか?
違いがわかりましたか?
そう、Gメールでは当たり前のように迷惑メールを弾いてくれますが、これはAIの力なのです。(アルゴリズムはもっと複雑です。)
けど、やっていることは同じで、ある直線(面)を引いてスパムか有効なメールかを分析しているのです。
美しいし、凄いですよね。
さて、今日の最終の目的の画像を見せますね。
以下になります。
⚠︎注意したいことがあります。
それはこの例の場合だと、無限の線が引けると言うことです。
よって、この点線の幅(マージ)が一番広くなった上で、その真ん中に実線(決定境界)を引きます。
さてこの決定境界を求める為のコードを見ていきます。
サンプルコード
import numpy as np
from sklearn.svm import SVC
import matplotlib.pyplot as plt
xmin = 0
xmax = 2
ymin = 0
ymax = 2
# 0 ~ 1 までの正規分布に従った値が入った(100行2列(X&Y))の行列を作成
X0 = np.random.uniform(size=(100, 2))
y0 = np.repeat(0, 100)
X1 = np.random.uniform(low=1.0, high=2.0, size=(100, 2))
y1 = np.repeat(1, 100)
svc = SVC(kernel='linear', C=1e8)
# 学習
# svc.fit(train_features, train_labels)
svc.fit(np.vstack((X0, X1)), np.hstack((y0, y1)))
fig, ax = plt.subplots()
# ax.scatter(x, y) 散布図
ax.scatter(X0[:, 0], X0[:, 1], marker='o', label='class 0')
ax.scatter(X1[:, 0], X1[:, 1], marker='x', label='class 1')
# 格子点を準備
xx, yy = np.meshgrid(np.linspace(xmin, xmax, 100), np.linspace(ymin, ymax, 100))
# decision_functionの為に一つの行列に。
xy = np.vstack([xx.ravel(), yy.ravel()]).T
# 100 * 100 の行列。
# z = x*2 + y*2 とする部分のzの高さの部分。
# svcの決定関数を使ってxyという格子点からzを求めている。
z = svc.decision_function(xy).reshape((100, 100))
# 決定境界とマージンをプロット
ax.contour(xx, yy, z, colors='k', levels=[-1, 0, 1], alpha=0.3, linestyles=['--', '-', '--'])
plt.show()
流れとしては、
- データを準備
- SVCのインスタンスを作成
- 入力データをそれを紐付かせたラベルをfit関数に入れることで学習させる。
- 格子点を作成(参考記事)
- Zをfitで学んだdecision_functionを使って作成
- contourを使って等高線を引くと、決定境界を書くことができる。
大切な関数はfitとdecition_funkctionである。
fitの使い方は理解したが、decision_functionの方は分類器を利用してZを求めることくらいしか理解できてない。
それはもう少し勉強が進んで記事にしようと思う。
今日は夜が遅いのでこれくらいで。
おやすみなさい。