0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

『機械学習のエッセンス(http://isbn.sbcr.jp/93965/)』のPythonサンプルをJuliaで書き換えてみる。(第05章07サポートベクタマシン)

Last updated at Posted at 2019-04-07

はじめに

『機械学習のエッセンス(http://isbn.sbcr.jp/93965/)』のPythonサンプルをJuliaで書き換えてみる。(第05章06ロジスティック回帰)の続きです。

サポートベクタマシン

ここでの説明はかなり長いので、ピンポイントで数式だけいくつか書きます。これだけだと分かりづらいと思いますので、詳しくは元のをご参照ください。

分類直線(一般には超曲面)が次で表されるとする。

y = w_0 + \sum_{i=1}^n w_ix_i = w_0 + {w}^Tx

変数$a$を導入。

a = (a_1,a_2,\dots,a_n)^T

次のような2次計画問題を解く

  • 式05-16
\begin{alignat}{2}
\text{Maximize} &\quad f(a) = \sum_{k=1}^n a_k - \frac{1}{2}\sum_{k=1}^n\sum_{l=1}^n a_{k}a_{l}y_{k}y_{l}x_{k}^Tx_{l} \\
\text{Subject to} &\quad  \sum_{i=1}^n a_{i}y_{i} = 0 \\
                  &\quad a_i \ge 0  \\
                 &\quad a_{i}\{y_i(w_0 + w^Tx_i) -1 \} = 0 \\
\end{alignat}

式05-16の最後の制約式により$y_i(w_0+w^Tx_i) \ne 1$の時は、$a_i = 0$となる。$y_i(w_0+w^Tx_i) = 1$というのは点がマージン境界線の上に乗っているという条件。マージン境界線の上にある点のことをサポートベクタという。

アルゴリズムの概要

  • 初期値$a^0$を選び、以下を繰り返す。
  1. ある基準にもどづきインデックス$i$,$j$を選択する。
  2. $a_i$と$a_j$だけを動かして、他を固定して最適な$a_i$,$a_j$を求める。

$a_i \gt 0$, $a_j \gt 0$という制約条件を無視したときに目的関数を最大化する$a_i$を$\hat{a}_i$とすると、$\hat{a}_i$は次で表される。

\hat{a}_i = \frac{1}{||x_i - x_j||^2}\{1 - y_{i}y_{j} + y_{i}(x_i - x_j)^T(x_j\sum_{k \ne i,j}a_{k}y_{k} - \sum_{k \ne i,j}a_{k}y_{k}x_{k}) \}

対応する$\hat{a}_j$は次で決まる。

\hat{a}_j = y_j(-\hat{a}_{i}y_{i} - \sum_{k \ne i,j}a_{k}y_{k})
  • $i$と$j$の選択のしかた
i = \underset{t \in I_{-}(y,a)}{argmin} \ y_t \nabla {f(a)}_{t} \\
j = \underset{t \in I_{+}(y,a)}{argmax} \ y_t \nabla {f(a)}_{t} 

argminとargmaxはそれぞれ、最小値・最大値を取るときのインデックスの値。
また、$I_{-}$と$I_{+}$は次のように定める。

I_{-}(y,a) = \{t \ | \ y_t = -1 または a_t > 0 \}  \\
I_{+}(y,a) = \{t \ | \ y_t = 1 または a_t > 0 \}
S = \{i \ | \ a_i \ne 0 \}

とすると$w_0$は下記で計算できる。

  • 式05-21
w_0 = \frac{1}{|S|}\sum_{k \in S}(y_k - \sum_{l \in S}a_{l}y_{l}x_{k}^Tx_{l})

サポートベクタマシンの実装

svm_hard.jl
module svm_hard
using LinearAlgebra

mutable struct SVC
    a_
    w_
    w0_
    function SVC()
        new(Nothing, Nothing, Nothing)
    end
end

function fit(s::SVC, X, y)
    a = zeros(size(X)[1])
    ay = 0
    ayx = zeros(size(X)[2])
    yx = y .* X
    while true
        ydf = y .* (1 .- (yx * ayx))
        i = findfirst(ydf .== minimum(ydf[(y .< 0) .| (a .> 0)]))
        j = findfirst(ydf .== maximum(ydf[(y .> 0) .| (a .> 0)]))
        if ydf[i] >= ydf[j]
            break
        end
        ay2 = ay - y[i]*a[i] - y[j]*a[j]
        ayx2 = ayx .- y[i]*a[i].*X[i, :] .- y[j]*a[j].*X[j, :]
        ai = (1 - y[i]*y[j] + y[i] * dot( (X[i, :] .- X[j, :]) , (X[j, :] .* ay2 .- ayx2) ) ) / sum((X[i, :] .- X[j, :]).^2)
        ai = (ai < 0 ? 0 : ai)
        aj = (-ai * y[i] - ay2) * y[j]
        if aj < 0
            aj = 0
            ai = (-aj * y[j] - ay2) * y[i]
        end

        ay = ay + y[i] * (ai - a[i]) + y[j] * (aj -a[j])
        ayx = ayx .+ y[i] * (ai - a[i]) .* X[i, :] + y[j] * (aj -a[j]) .* X[j, :]
        if ai == a[i]
            break
        end
        a[i] = ai
        a[j] = aj
    end
    s.a_ = a
    ind = a .!= 0.
    s.w_ = sum((a[ind] .* y[ind]) .* X[ind, :], dims=1)
    s.w0_ = sum(y[ind] .- (X[ind, :] * (s.w_)')) / sum(ind)
end

function predict(s::SVC, X)
    sign.(s.w0_ .+ X * (s.w_)')
end

end
svm_hard_test1.jl
include("svm_hard.jl")

using .svm_hard
using Plots
using Random

Random.seed!(0)
X0 = randn(20, 2)
X1 = randn(20, 2) .+ [5 5]
y = vcat([1 for x in 1:20], [-1 for x in 1:20])

X = vcat(X0, X1)

model = svm_hard.SVC()
svm_hard.fit(model, X, y)

scatter(X0[:, 1], X0[:, 2], color="black", markershape=:+, label="")
scatter!(X1[:, 1], X1[:, 2], color="black", markershape=:star6, label="")

f(model, x) = (-model.w0_ - model.w_[1] * x) / model.w_[2]

x1 = -0.2
x2 = 6
plot!([x1, x2], [f(model, x1), f(model, x2)], color="black", label="")
tf = model.a_ .!= 0
Xfalse = [false for x in 1:size(X)[1]]
scatter!(X[hcat(tf, Xfalse)], X[hcat(Xfalse, tf)], color="red", markersize=10, markershape=:circle, markeralpha=0.1, label="")

実行結果

julia> include("svm_hard_test1.jl")

スクリーンショット 2019-04-07 20.42.14.png

  • サポートベクタを白抜きの丸で囲む描画がうまくいかず、赤色の透明度を上げて塗っています。

補足1(インデックスの取得について)

本のPythonのサンプルでは最小値、最大値のインデックスを取得するためにiydfという元のインデックスを保持した配列を作っています。Juliaで同じように作ろうとしたのですがなかなかうまくいかず、代わりにインデックスの取得の方法を調べたところいろいろ書き方があったので下記にまとめました。
Juliaの配列で最小値・最大値のインデックスを取得する
svm_hard.jlijはこの方法で取得しています。下記に簡単なサンプルを載せておきます。(実際に使用したのはfindfirst関数ですが、下記のサンプルではfindall関数を使用しています。)

julia> A = [4,2,3,5,9,1,6]                               
7-element Array{Int64,1}:
 4    
 2    
 3    
 5    
 9    
 1    
 6 

julia> Atf = [true, false, true, false, true, false,true]
7-element Array{Bool,1}:
  true
 false
  true
 false
  true
 false
  true

julia> A[Atf]
4-element Array{Int64,1}:
 4
 3
 9
 6

julia> findall(A .== maximum(A))
1-element Array{Int64,1}:
 5

julia> findall(A .== maximum(A[Atf]))
1-element Array{Int64,1}:
 5

補足2(最初に与えるランダムな値について)

ここのロジックを追うのはなかなか難しく、実際に動いた後も正しいか不安だったので、本のPythonサンプルが最初にRandomで作る配列を埋め込んで確認してみました。

svm_hard_test1.jl
(略)

Random.seed!(0)
#X0 = randn(20, 2)
X0 = [ 1.76405235  0.40015721;
 0.97873798  2.2408932 ;
 1.86755799 -0.97727788;
 0.95008842 -0.15135721;
-0.10321885  0.4105985 ;
 0.14404357  1.45427351;
 0.76103773  0.12167502;
 0.44386323  0.33367433;
 1.49407907 -0.20515826;
 0.3130677  -0.85409574;
-2.55298982  0.6536186 ;
 0.8644362  -0.74216502;
 2.26975462 -1.45436567;
 0.04575852 -0.18718385;
 1.53277921  1.46935877;
 0.15494743  0.37816252;
-0.88778575 -1.98079647;
-0.34791215  0.15634897;
 1.23029068  1.20237985;
-0.38732682 -0.30230275]
#X1 = randn(20, 2) .+ [5 5]
X1 = [3.95144703 3.57998206;
3.29372981 6.9507754 ;
4.49034782 4.5619257 ;
3.74720464 5.77749036;
3.38610215 4.78725972;
4.10453344 5.3869025 ;
4.48919486 3.81936782;
4.97181777 5.42833187;
5.06651722 5.3024719 ;
4.36567791 4.63725883;
4.32753955 4.64044684;
4.18685372 3.2737174 ;
5.17742614 4.59821906;
3.36980165 5.46278226;
4.09270164 5.0519454 ;
5.72909056 5.12898291;
6.13940068 3.76517418;
5.40234164 4.31518991;
4.12920285 4.42115034;
4.68844747 5.05616534]
y = vcat([1 for x in 1:20], [-1 for x in 1:20])

X = vcat(X0, X1)

(略)

実行結果

julia> include("svm_hard_test1.jl")

スクリーンショット 2019-04-07 21.01.37.png

  • 描画のサポートベクタも本と同じ結果になっています。

分離不可能な場合

スラックス変数$\xi_i \ge 0 \ (i = 1,\dots,n)$を考える。

y_i(w_0 + w^Tx) \ge 1 - \xi_i

解くべき最適化問題。$C$は調整のための定数。

\begin{alignat}{2}
\text{Minimize} &\quad C\sum_{i=1}^n \xi_i + \frac{1}{2}\sum_{i=1}^n ||w||^2 \\
\text{Subject to} &\quad  y_i(w_0 + w^Tx) \ge 1 - \xi_i \ (i = 1,\dots,n) \\
                  &\quad  \xi_i \ge 0 \ (i = 1,\dots,n)
\end{alignat}

解くべき最適化問題をまとめると次のようにになる。

  • 式05-22
\begin{alignat}{2}
\text{Maximize} &\quad f(a) = \sum_{k=1}^n a_k - \frac{1}{2}\sum_{k=1}^n\sum_{l=1}^n a_{k}a_{l}y_{k}y_{l}x_{k}^Tx_{l} \\
\text{Subject to} &\quad  \sum_{i=1}^n a_{i}y_{i} = 0 \\
                  &\quad 0 \le a_i \le C \\
\end{alignat}
I_{-}(y,a) = \{t \ | \ (a_t \gt 0 かつ y_t = 1) または (a_t \lt C かつ y_t = -1) \}  \\
I_{+}(y,a) = \{t \ | \ (a_t \gt 0 かつ y_t = -1) または (a_t \lt C かつ y_t = 1)  \}

とおけば

i = \underset{t \in I_{-}(y,a)}{argmin} \ y_t \nabla {f(a)}_{t} \\
j = \underset{t \in I_{+}(y,a)}{argmax} \ y_t \nabla {f(a)}_{t} 

実装

svm_soft.jl
module svm_soft
using LinearAlgebra

mutable struct SVC
    a_
    w_
    w0_
    C
    max_iter::Int
    function SVC(C=1.)
        new(Nothing, Nothing, Nothing, C, 10000)
    end
end

function fit(s::SVC, X, y)
    a = zeros(size(X)[1])
    ay = 0
    ayx = zeros(size(X)[2])
    yx = y .* X
    for i in 1:s.max_iter
        ydf = y .* (1 .- (yx * ayx))
        i = findfirst(ydf .== minimum(ydf[((a .> 0) .& (y .> 0)) .| ((a .< s.C) .& (y .< 0))]))
        j = findfirst(ydf .== maximum(ydf[((a .> 0) .& (y .< 0)) .| ((a .< s.C) .& (y .> 0))]))
        if ydf[i] >= ydf[j]
            break
        end
        ay2 = ay - y[i]*a[i] - y[j]*a[j]
        ayx2 = ayx .- y[i]*a[i].*X[i, :] .- y[j]*a[j].*X[j, :]
        ai = (1 - y[i]*y[j] + y[i] * dot( (X[i, :] .- X[j, :]) , (X[j, :] .* ay2 .- ayx2) ) ) / sum((X[i, :] .- X[j, :]).^2)
        if ai < 0
            ai = 0
        elseif ai > s.C
            ai = s.C
        end
        aj = (-ai * y[i] - ay2) * y[j]
        if aj < 0
            aj = 0
            ai = (-aj * y[j] - ay2) * y[i]
        elseif aj > s.C
            aj = s.C
            ai = (-aj * y[j] - ay2) * y[i]
        end
        ay = ay + y[i] * (ai - a[i]) + y[j] * (aj -a[j])
        ayx = ayx .+ y[i] * (ai - a[i]) .* X[i, :] + y[j] * (aj -a[j]) .* X[j, :]
        if ai == a[i]
            break
        end
        a[i] = ai
        a[j] = aj
    end
    s.a_ = a
    ind = a .!= 0.
    s.w_ = sum((a[ind] .* y[ind]) .* X[ind, :], dims=1)
    s.w0_ = sum(y[ind] .- (X[ind, :] * (s.w_)')) / sum(ind)
end

function predict(s::SVC, X)
    sign.(s.w0_ .+ X * (s.w_)')
end

end
svm_soft_test1.jl
include("svm_soft.jl")

using .svm_soft
using Plots
using Random

Random.seed!(0)
X0 = randn(20, 2)
X1 = randn(20, 2) .+ [2.5 3]

y = vcat([1 for x in 1:20], [-1 for x in 1:20])

X = vcat(X0, X1)

model = svm_soft.SVC()
svm_soft.fit(model, X, y)

scatter(X0[:, 1], X0[:, 2], color="black", markershape=:+, label="")
scatter!(X1[:, 1], X1[:, 2], color="black", markershape=:star6, label="")

f(model, x) = (-model.w0_ - model.w_[1] * x) / model.w_[2]

x1 = -2
x2 = 4
plot!([x1, x2], [f(model, x1), f(model, x2)], color="black", label="")
tf = model.a_ .!= 0
Xfalse = [false for x in 1:size(X)[1]]
print("正しく分類できた数:", sum(svm_soft.predict(model, X) .== y))
scatter!(X[hcat(tf, Xfalse)], X[hcat(Xfalse, tf)], color="red", markersize=10, markershape=:circle, markeralpha=0.1, label="")

実行結果

julia> include("svm_soft_test1.jl")
正しく分類できた数:38

スクリーンショット 2019-04-07 21.22.25.png

補足1

本の実装と違い、max_iterでループの最大値を設定しています。これは実際に動かしたときに無限ループになってしまったからです。実装が間違っていると思ったのですが、下記の補足2でPythonサンプルで最初に作るRandomな配列を埋め込んで動かしたところうまく動きました。ロジックは合っているものとして、無限ループを防ぐために最大値を設定しました。

(追記)@tenfu2tea さんからコメントをいただきました。無限ループの原因は svm_soft.jl

if ai == a[i]
  break
end

の所が浮動小数点数同士の比較になっていたからでした。これは第04章01数値計算の基本でも指摘されていることで、見逃していました。もともとのPythonサンプルもこの比較方法になっているので、たまたま一致して無限ループにならないだけではないかと思います。
教えていただいた isapprox 関数を使うことで無限ループを回避出来ました。
https://docs.julialang.org/en/v1/base/math/#Base.isapprox
本文のソースコードはそのままにしておきます。回避した内容はコメントをご参照ください。(次章以降は書き換えも isapprox 関数を利用する予定です。)

補足2(最初に与えるランダムな値について)

本のPythonサンプルが最初にRandomで作る配列を埋め込んで確認してみました。

svm_soft_test1.jl
(略)

Random.seed!(0)
#X0 = randn(20, 2)
X0 = [ 1.76405235  0.40015721;
  0.97873798  2.2408932 ;
  1.86755799 -0.97727788;
  0.95008842 -0.15135721;
 -0.10321885  0.4105985 ;
  0.14404357  1.45427351;
  0.76103773  0.12167502;
  0.44386323  0.33367433;
  1.49407907 -0.20515826;
  0.3130677  -0.85409574;
 -2.55298982  0.6536186 ;
  0.8644362  -0.74216502;
  2.26975462 -1.45436567;
  0.04575852 -0.18718385;
  1.53277921  1.46935877;
  0.15494743  0.37816252;
 -0.88778575 -1.98079647;
 -0.34791215  0.15634897;
  1.23029068  1.20237985;
 -0.38732682 -0.30230275;]

#X1 = randn(20, 2) .+ [2.5 3]
X1 = [ 1.45144703 1.57998206;
 0.79372981 4.9507754 ;
 1.99034782 2.5619257 ;
 1.24720464 3.77749036;
 0.88610215 2.78725972;
 1.60453344 3.3869025 ;
 1.98919486 1.81936782;
 2.47181777 3.42833187;
 2.56651722 3.3024719 ;
 1.86567791 2.63725883;
 1.82753955 2.64044684;
 1.68685372 1.2737174 ;
 2.67742614 2.59821906;
 0.86980165 3.46278226;
 1.59270164 3.0519454 ;
 3.22909056 3.12898291;
 3.63940068 1.76517418;
 2.90234164 2.31518991;
 1.62920285 2.42115034;
 2.18844747 3.05616534;]

y = vcat([1 for x in 1:20], [-1 for x in 1:20])

X = vcat(X0, X1)

(略)

実行結果

julia> include("svm_soft_test1.jl")
WARNING: replacing module svm_soft.
正しく分類できた数:37

スクリーンショット 2019-04-07 21.27.50.png

  • 正しく分離できた数も本と同じ結果です。
  • 描画のサポートベクタも本と同じ結果になっています。

カーネル法

カーネル法ではある写像$\phi:\mathbb{R}^d \to \mathbb{R}^d$を用いて、分離超曲面が次で表されると仮定する。

w_0 + w^T\phi(x)

この場合、式05-22にあたる最適化問題は次のようになる。

\begin{alignat}{2}
\text{Maximize} &\quad f(a) = \sum_{k=1}^n a_k - \frac{1}{2}\sum_{k=1}^n\sum_{l=1}^n a_{k}a_{l}y_{k}y_{l}\phi(x_{k})^T\phi(x_{l}) \\
\text{Subject to} &\quad  \sum_{i=1}^n a_{i}y_{i} = 0 \\
                  &\quad 0 \le a_i \le C \\
\end{alignat}

ここで

\phi(x_{k})^T\phi(x_{l}) = K(X_k, X_l)

とおくと目的関数は次のようになる。

f(a) = \sum_{k=1}^n a_k - \frac{1}{2}\sum_{k=1}^n\sum_{l=1}^n a_{k}a_{l}y_{k}y_{l}K(x_k, x_l)

この関数$K$はカーネル関数と呼ばれる。

式05-21と同様な計算をカーネル関数を使って行う。

w_0 = \frac{1}{|S|}\sum_{k \in S}(y_k - \sum_{l \in S}a_{l}y_{l}K(x_k, x_l))

カーネル関数$K$でよく使われるのはRBF(放射基底関数)で、次で定義される。

K(u, v) = exp(-\frac{||u - v||^2}{2\sigma^2})

実装

svm.jl
module svm
using LinearAlgebra

mutable struct RBFKernel
    X
    σ2
    values_
    function RBFKernel(X, σ)
        new(X, σ^2, similar(Array{Float64}, (size(X)[1],size(X)[1])))
    end
end

value(obj::RBFKernel, i, j)  = exp((sum(-(obj.X[i, :] .- obj.X[j, :]).^2)) / (2*obj.σ2))

function eval(obj::RBFKernel, Z, s)
    XX = obj.X[s, :]
    X_Z = [sum((XX[i, :] .- Z[j, :]).^2) for i in 1:size(XX)[1], j in 1:size(Z)[1]]
    col, row = size(X_Z)
    if (col, row) == (0, 0)
        return []
    end
    exp.(-X_Z ./ (2*obj.σ2))
end

mutable struct SVC
   a_
   w0_
   y_
   kernel_
   C::Float64
   σ
   max_iter::Int
   function SVC()
       new(Nothing, Nothing, Nothing, Nothing, 1., 1, 10000)
   end
end

function fit(obj::SVC, X, y)
    a = zeros(size(X)[1])
    ay = 0
    kernel = RBFKernel(X, obj.σ)
    for i in 1:obj.max_iter
        s = a .!= 0.
        if isempty(a[s].*y[s]) || isempty(svm.eval(kernel, X, s))
            ydf = y
            i = findfirst(ydf .== minimum(ydf[(((a .> 0) .& (y .> 0)) .| ((a .< obj.C) .& (y .< 0)))]))
            j = findfirst(ydf .== maximum(ydf[(((a .> 0) .& (y .< 0)) .| ((a .< obj.C) .& (y .> 0)))]))
        else
            ydf = y .* (1 .- (y .* ((a[s].*y[s])' * svm.eval(kernel, X, s))'))
            i = findfirst(ydf .== minimum(ydf[(((a .> 0) .& (y .> 0)) .| ((a .< obj.C) .& (y .< 0)))]))
            j = findfirst(ydf .== maximum(ydf[(((a .> 0) .& (y .< 0)) .| ((a .< obj.C) .& (y .> 0)))]))
        end
        if ydf[i] >= ydf[j]
            break
        end

        ay2 = ay - y[i]*a[i] - y[j]*a[j]
        kii = svm.value(kernel, i, i)
        kij = svm.value(kernel, i, j)
        kjj = svm.value(kernel, j, j)
        s = a .!= 0.
        s[i] = false
        s[j] = false
        kxi = (svm.eval(kernel, reshape(X[i, :], 1, :), s))
        kxj = (svm.eval(kernel, reshape(X[j, :], 1, :), s))
        ai = (1 - y[i]*y[j] + y[i]*( (kij - kjj)*ay2 - sum( (a[s].*y[s].*(kxi .- kxj) == []) ? 0 : a[s].*y[s].*(kxi .- kxj) ) ) ) / (kii + kjj - 2*kij)
        if ai < 0
            ai = 0
        elseif ai > obj.C
            ai = obj.C
        end

        aj = (-ai*y[i] - ay2)*y[j]
        if aj < 0
            aj = 0
            ai = (-ai*y[j] - ay2)*y[i]
        elseif aj > obj.C
            aj = obj.C
            ai = (-ai*y[j] - ay2)*y[i]
        end
        ay = ay + y[i] * (ai - a[i]) + y[j] * (aj -a[j])
        if ai == a[i]
            break
        end
        a[i] = ai
        a[j] = aj
    end
    obj.a_ = a
    obj.y_ = y
    obj.kernel_ = kernel
    s = a .!= 0.
    obj.w0_ = sum(y[s] - ((a[s].*y[s])' * eval(kernel, X[s, :], s))') / sum(s)
end

function predict(obj::SVC, X)
    s = obj.a_ .!= 0.
    sign.(obj.w0_ .+ ((obj.a_[s].*obj.y_[s])' * eval(obj.kernel_, X, s))')
end

end
svm_kernel_test1.jl
include("svm.jl")

using .svm
using Plots
using Random

Random.seed!(0)
X0 = randn(100, 2)
X1 = randn(100, 2) .+ [2.5 3]

y = vcat([1 for x in 1:100], [-1 for x in 1:100])

X = vcat(X0, X1)

model = svm.SVC()
svm.fit(model, X, y)

xmin = minimum(X[:,1])
xmax = maximum(X[:,1])
ymin = minimum(X[:,2])
ymax = maximum(X[:,2])

x_range = LinRange(xmin, xmax, 200)
y_range = LinRange(ymin, ymax, 200)
xmesh = repeat(x_range', outer=(length(y_range),1))
ymesh = repeat(y_range,  outer=(1,length(x_range)))

Z = reshape(svm.predict(model, hcat(vec(xmesh), vec(ymesh))), size(xmesh))

print("正しく分類できた数:", sum(svm.predict(model, X) .== y))

scatter(X0[:, 1], X0[:, 2], color="black", markershape=:star6, label="")
scatter!(X1[:, 1], X1[:, 2], color="black", markershape=:+, label="")
contour!(x_range, y_range, Z, levels=[0], color=:black)

実行結果

julia> include("svm_kernel_test1.jl")
正しく分類できた数:198

スクリーンショット 2019-04-07 21.45.02.png

補足1

eval関数やydfの計算のときにブランクの配列([])があるとエラーになったため条件分岐で回避しました。もっといい書き方があったかもしれません。

補足2(最初に与えるランダムな値について)

本のPythonサンプルが最初にRandomで作る配列を埋め込んで確認してみました。

svm_kernel_test1.jl
(略)

Random.seed!(0)
#X0 = randn(100, 2)
X0 = [1.76405235  0.40015721;
 0.97873798  2.2408932 ;
 1.86755799 -0.97727788;
 0.95008842 -0.15135721;
-0.10321885  0.4105985 ;
 0.14404357  1.45427351;
 0.76103773  0.12167502;
 0.44386323  0.33367433;
 1.49407907 -0.20515826;
 0.3130677  -0.85409574;
-2.55298982  0.6536186 ;
 0.8644362  -0.74216502;
 2.26975462 -1.45436567;
 0.04575852 -0.18718385;
 1.53277921  1.46935877;
 0.15494743  0.37816252;
-0.88778575 -1.98079647;
-0.34791215  0.15634897;
 1.23029068  1.20237985;
-0.38732682 -0.30230275;
-1.04855297 -1.42001794;
-1.70627019  1.9507754 ;
-0.50965218 -0.4380743 ;
-1.25279536  0.77749036;
-1.61389785 -0.21274028;
-0.89546656  0.3869025 ;
-0.51080514 -1.18063218;
-0.02818223  0.42833187;
 0.06651722  0.3024719 ;
-0.63432209 -0.36274117;
-0.67246045 -0.35955316;
-0.81314628 -1.7262826 ;
 0.17742614 -0.40178094;
-1.63019835  0.46278226;
-0.90729836  0.0519454 ;
 0.72909056  0.12898291;
 1.13940068 -1.23482582;
 0.40234164 -0.68481009;
-0.87079715 -0.57884966;
-0.31155253  0.05616534;
-1.16514984  0.90082649;
 0.46566244 -1.53624369;
 1.48825219  1.89588918;
 1.17877957 -0.17992484;
-1.07075262  1.05445173;
-0.40317695  1.22244507;
 0.20827498  0.97663904;
 0.3563664   0.70657317;
 0.01050002  1.78587049;
 0.12691209  0.40198936;
 1.8831507  -1.34775906;
-1.270485    0.96939671;
-1.17312341  1.94362119;
-0.41361898 -0.74745481;
 1.92294203  1.48051479;
 1.86755896  0.90604466;
-0.86122569  1.91006495;
-0.26800337  0.8024564 ;
 0.94725197 -0.15501009;
 0.61407937  0.92220667;
 0.37642553 -1.09940079;
 0.29823817  1.3263859 ;
-0.69456786 -0.14963454;
-0.43515355  1.84926373;
 0.67229476  0.40746184;
-0.76991607  0.53924919;
-0.67433266  0.03183056;
-0.63584608  0.67643329;
 0.57659082 -0.20829876;
 0.39600671 -1.09306151;
-1.49125759  0.4393917 ;
 0.1666735   0.63503144;
 2.38314477  0.94447949;
-0.91282223  1.11701629;
-1.31590741 -0.4615846 ;
-0.06824161  1.71334272;
-0.74475482 -0.82643854;
-0.09845252 -0.66347829;
 1.12663592 -1.07993151;
-1.14746865 -0.43782004;
-0.49803245  1.92953205;
 0.94942081  0.08755124;
-1.22543552  0.84436298;
-1.00021535 -1.5447711 ;
 1.18802979  0.31694261;
 0.92085882  0.31872765;
 0.85683061 -0.65102559;
-1.03424284  0.68159452;
-0.80340966 -0.68954978;
-0.4555325   0.01747916;
-0.35399391 -1.37495129;
-0.6436184  -2.22340315;
 0.62523145 -1.60205766;
-1.10438334  0.05216508;
-0.739563    1.5430146 ;
-1.29285691  0.26705087;
-0.03928282 -1.1680935 ;
 0.52327666 -0.17154633;
 0.77179055  0.82350415;
 2.16323595  1.33652795]

#X1 = randn(100, 2) .+ [2.5 3]
X1 =[ 2.13081816  2.76062082;
 3.5996596   3.65526373;
 3.14013153  1.38304396;
 2.47567388  2.26196909;
 2.7799246   2.90184961;
 3.41017891  3.31721822;
 3.28632796  2.5335809 ;
 1.55555374  2.58995031;
 2.48297959  3.37915174;
 4.75930895  2.95774285;
 1.544055    2.65401822;
 2.03640403  3.48148147;
 0.95920299  3.06326199;
 2.65650654  3.23218104;
 1.90268393  2.76207827;
 1.07593909  2.50668012;
 1.95713852  3.41605005;
 1.34381757  3.7811981 ;
 3.99448454  0.93001497;
 2.92625873  3.67690804;
 1.86256297  2.60272819;
 2.36711942  2.70220912;
 2.19098703  1.32399619;
 3.65233156  4.07961859;
 1.68663574  1.53357567;
 3.02106488  2.42421203;
 2.64195316  2.68067158;
 3.19153875  3.69474914;
 1.77440262  1.61663604;
 0.9170616   3.61037938;
 1.31114074  2.49318365;
 1.90368596  2.9474327 ;
 0.56372019  3.1887786 ;
 3.02389102  3.08842209;
 2.18911383  3.09740017;
 2.89904635  0.22740724;
 4.45591231  3.39009332;
 1.84759142  2.60904662;
 2.99374178  2.88389606;
 0.46931553  5.06449286;
 2.38945934  4.02017271;
 1.80795015  4.53637705;
 2.78634369  3.60884383;
 1.45474663  4.21114529;
 3.18981816  4.30184623;
 1.87191244  2.51897288;
 4.8039167   1.93998418;
 2.3640503   4.13689136;
 2.59772497  3.58295368;
 2.10055097  3.37005589;
 1.19347315  4.65813068;
 2.38183595  2.3198218 ;
 3.16638308  2.53928021;
 1.16574153  1.65328249;
 3.19377315  2.84042656;
 2.36629844  4.07774381;
 1.37317419  2.26932225;
 2.11512019  3.09435159;
 2.45782855  2.71311281;
 2.4383736   2.89269472;
 1.78039561  2.18700701;
 2.77451636  2.10908492;
 1.34264474  2.68770775;
 2.34233298  5.2567235 ;
 1.79529972  3.94326072;
 3.24718833  1.81105504;
 3.27325298  1.81611936;
-0.15917224  3.60631952;
 0.74410942  3.45093446;
 1.8159891   4.6595508 ;
 3.5685094   2.5466142 ;
 1.81216239  1.7859226 ;
 2.05907737  2.7196445 ;
 2.13530646  3.15670386;
 3.0785215   3.34965446;
 1.73585608  1.56220853;
 3.86453185  2.31055082;
 1.8477064   2.47881069;
 0.65693045  2.522026  ;
 2.02034419  3.6203583 ;
 3.19845715  3.00377089;
 3.43184837  3.33996498;
 2.48431789  3.16092817;
 2.30934651  2.60515049;
 2.23226646  1.87198867;
 2.78044171  2.00687639;
 3.34163126  2.75054142;
 2.54949498  3.49383678;
 3.14331447  1.42937659;
 2.29309632  3.88017891;
 0.80189418  3.38728048;
 0.24443577  1.97749316;
 2.53863055  1.3432849 ;
 1.51448926  1.52816499;
 4.14813493  3.16422776;
 3.06729028  2.7773249 ;
 2.14656825  1.38352581;
 2.20816264  2.23850779;
 3.35792392  4.14110187;
 3.96657872  3.85255194]

y = vcat([1 for x in 1:100], [-1 for x in 1:100])

X = vcat(X0, X1)

(略)

実行結果

julia> include("svm_kernel_test1.jl")
WARNING: replacing module svm.
正しく分類できた数:193

スクリーンショット 2019-04-07 21.52.14.png

  • 正しく分離できた数も本と同じ結果です。
  • 描画も本と同じ結果になっています。
0
1
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?