Julia v0.3.0でSVM.jlを使う

  • 8
    いいね
  • 0
    コメント
この記事は最終更新日から1年以上が経過しています。

exampleを実行したいだけなのに、またハマったのでメモがてら書きます。

最終的に修正したexampleはこちら。(現在PR投げてますがいつmergeされるやら...)

# To show how SVMs work, we'll use Fisher's iris data set
# Temporally, following order is required 
using RDatasets
using SVM


# We'll learn to separate setosa from other species
iris = dataset("datasets", "iris")

# SVM format expects observations in columns and features in rows
X = array(iris[:, 1:4])'
p, n = size(X)

# SVM format expects positive and negative examples to +1/-1
Y = [species == :setosa ? 1.0 : -1.0 for species in iris[:, "Species"]]

# Select a subset of the data for training, test on the rest.
train = randbool(n)

# We'll fit a model with all of the default parameters
model = svm(X[:,train], Y[train])

# And now evaluate that model on the testset
accuracy = countnz(predict(model, X[:,~train]) .== Y[~train])/countnz(~train)

ちなみに、比較対象のcommitは 20bd50ea72fc4f94ec9e4adebf48ec7f16c717ca

ポイントは2点。

1. RDatasetsの使い方が変わった

-iris = data("datasets", "iris")
+iris = dataset("datasets", "iris")

こちらの所沢さんの資料を読んでいたら、P.11に書いてあって気づいた。

2. predictがRDatasetsのものとconflictする

accuracy = countnz(predict(model, X[:,~train]) .== Y[~train])/countnz(~train)

ここの部分だけど、うまくいったり行かなかったりしたので凄い悩んだ。

元々のexampleだと、型が合わないとエラーが出てしまう

julia> accuracy = countnz(predict(model, X[:,~train]) .== Y[~train])/countnz(~train)
ERROR: no method predict(SVMFit, Array{Float64,2})
julia> using SVM
julia> using RDatasets
julia> methods(predict)
# 2 methods for generic function "predict":
predict(::DataFrameRegressionModel{M,T},...) at /Users/chezou/.julia/v0.3/DataFrames/src/statsmodels/statsmodel.jl:27
predict(obj::RegressionModel) at /Users/chezou/.julia/v0.3/StatsBase/src/statmodels.jl:20

どうやら、DataFramesのpredictを読んでしまっていたらしい。

(bicycle1885さんありがとうございます!)

これは、usingの順番を変えることで対応出来ました。

julia> using RDatasets
julia> using SVM
julia> methods(predict)
# 2 methods for generic function "predict":
predict(fit::SVMFit,X::Array{T,2}) at /Users/mic/.julia/v0.3/SVM/src/SVM.jl:22
predict(fit::SVMFit) at /Users/mic/.julia/v0.3/SVM/src/SVM.jl:33

methodsとかで調べるみたいな話、軽く使う分には知らないのでハマった時の対処法を身につけていきたい。

はー、さくっとやろうと思ったら、2,3時間はまってしまった。