0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (13)

Last updated at Posted at 2020-06-12

前回
筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (12)
https://github.com/legacyworld/sklearn-basic

課題 6.2 カーネルとSVM

前回はガウシアンカーネルの部分まで出来なかったので引き続き課題 6.2を解いていく
ガウシアンカーネルでは2つのパラメータをチューニングしている。

  • 正則化パラメータのC
    • scikit-learnのsvm.SVCではCに反比例して正則化の効果が大きくなる
  • ハイパーパラメータ
    • $k(x,x') = \exp(- \frac{||x-x'||}{\sigma^2})$の$\frac{1}{\sigma^2}$の部分
    • プログラムでは$\gamma$(gamma)としている

講義ではハイパーパラメータであるgammaをプルダウンで選ぶと、正則化パラメータを0.01から300まで10刻み動かして最適な(=テスト誤差が最小)値を選んで描画している
プログラムでは全てのgamma(1000,100,1,0.1,0.01)にたいして一気にPNGに保存している。
因みに講義のプログラムではgammaは逆数になっている。

Homework_6.2_rbf.py
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.colors as mcolors
from sklearn import svm,metrics
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_circles,make_moons,make_blobs

datanames = ['linear_separation','moons','circles']
samples = 200
c_values = [i/100 for i in range(1,30000,1000)]
# 3種類のデータ作成
def datasets(dataname):
    if dataname == 'linear_separation':
        X,y = make_blobs(n_samples=samples,centers=2,random_state=64)
    elif dataname == 'moons':
        X,y = make_moons(n_samples=samples,noise=0.3,random_state=74)
    elif dataname == 'circles':
        X,y = make_circles(n_samples=samples,noise=0.3,random_state=70)
    
    X = preprocessing.MinMaxScaler(feature_range=(-1,1)).fit_transform(X)
    return X,y
# 全てのCに対してaccuracy_scoreを返す関数
def learn_test_plot(clf_models):
    accs = {}
    for clf in clf_models:
        for dataname in datanames:
            if dataname not in accs:
                accs[dataname] = []
            X,y = datasets(dataname)
            X_tr_val,X_test,y_tr_val,y_test = train_test_split(X,y,test_size=0.3,random_state=42)
            X_tr,X_val,y_tr,y_val = train_test_split(X_tr_val,y_tr_val,test_size=0.2,random_state=42)
            clf.fit(X_tr,y_tr)
            predict = clf.predict(X_val)
            train_acc = metrics.accuracy_score(y_val,predict)
            accs[dataname].append(train_acc)
    return accs
# gammaを動かして描画
gamma_list = [1000,100,1,0.1,0.01]
for gamma in gamma_list:
    plt.clf()
    clf_models = [svm.SVC(kernel='rbf',gamma=gamma,C=c_value) for c_value in c_values]
    accs = learn_test_plot(clf_models)
    fig = plt.figure(figsize=(20,10))
    ax = [fig.add_subplot(2,3,i+1) for i in range(6)]
    for a in ax:
        a.set_xlim(-1.5,1.5)
        a.set_ylim(-1.5,1.5)
    for dataname in datanames:
        best_c_index = np.argmax(accs[dataname])
        X,y = datasets(dataname)
        X_tr_val,X_test,y_tr_val,y_test = train_test_split(X,y,test_size=0.3,random_state=42)
        X_tr,X_val,y_tr,y_val = train_test_split(X_tr_val,y_tr_val,test_size=0.2,random_state=42)
        clf = clf_models[best_c_index]
        
        clf.fit(X_tr,y_tr)
        train_predict = clf.predict(X_tr_val)
        test_predict = clf.predict(X_test)
        train_acc = metrics.accuracy_score(y_tr_val,train_predict) 
        test_acc = metrics.accuracy_score(y_test,test_predict)
        c_value = clf.get_params()['C']
        
        # メッシュデータ
        xlim = [-1.5,1.5]
        ylim = [-1.5,1.5]
        xx = np.linspace(xlim[0], xlim[1], 300)
        yy = np.linspace(ylim[0], ylim[1], 300)
        YY, XX = np.meshgrid(yy, xx)
        xy = np.vstack([XX.ravel(), YY.ravel()]).T
        Z = clf.decision_function(xy).reshape(XX.shape)
        # 塗りつぶし用の色
        blue_rgb = mcolors.to_rgb("tab:blue")
        red_rgb = mcolors.to_rgb("tab:red")
        # データセットごとに縦に並べる
        index = datanames.index(dataname)
        # decision_functionが大きいほど色を濃くする
        ax[index].contourf(XX, YY, Z,levels=[-2,-1,-0.1,0.1,1,2],colors=[red_rgb+(0.5,),red_rgb+(0.3,),(1,1,1),blue_rgb+(0.3,),blue_rgb+(0.5,)],extend='both')
        ax[index].contour(XX,YY,Z,levels=[0],linestyles=["--"])
        ax[index].scatter(X_tr_val[:,0],X_tr_val[:,1],c=y_tr_val,edgecolors='k',cmap=ListedColormap(['#FF0000','#0000FF']))
        ax[index].set_title(f"gamma = {gamma}\nTraining Accuracy = {train_acc} C = {c_value}")

        ax[index+3].contourf(XX, YY, Z,levels=[-2,-1,-0.1,0.1,1,2],colors=[red_rgb+(0.5,),red_rgb+(0.3,),(1,1,1),blue_rgb+(0.3,),blue_rgb+(0.5,)],extend='both')
        ax[index+3].contour(XX,YY,Z,levels=[0],linestyles=["--"])
        ax[index+3].scatter(X_test[:,0],X_test[:,1],c=y_test,edgecolors='k',cmap=ListedColormap(['#FF0000','#0000FF']))
        ax[index+3].set_title(f"gamma = {gamma}\nTest Accuracy = {test_acc} C = {c_value}")
    plt.savefig(f"6.2_{gamma}.png")

learn_test_plotの関数で全てのCに対してのaccuracy_scoreを返している。
そこから最も大きいaccuracy_scoreを返すCで描画する
$\gamma$が大きいと訓練データの影響の直径が小さくなる = 複雑な形になる
$\gamma$が大きい場合は決定境界は非常に複雑な形をしているが、$\gamma$が小さいと直線になってしまっている

make_blobsで作ったデータに対して、決定境界が各訓練データにまとわりつくようにすることはほぼ意味がない(塊が2つなら)。
逆にmake_circlesのデータに対しては、直線の決定境界は意味がない。
6.2_1000.png
6.2_1.png
6.2_0.01.png

過去の投稿

筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (1)
筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (2)
筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (3)
筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (4)
筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (5)
筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (6)
筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (7) 最急降下法を自作
筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (8) 確率的最急降下法を自作
筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (9)
筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (10)
筑波大学の機械学習講座:課題のPythonスクリプト部分を作りながらsklearnを勉強する (11)
https://github.com/legacyworld/sklearn-basic
https://ocw.tsukuba.ac.jp/course/systeminformation/machine_learning/

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?