何故こんなものを書いたのか
scalaだと、構文解析が書きやすいので・・・。
pythonを使わないで、scalaを使って機械学習をしてる人なんて少ないと思うので需要があるかどうか知りませんが置いておきます。
(他に良い書き方が絶対あると思う)
環境
この記事ではWindows8.1にscala2.11.7をインストールして使っています。
scalaの他に、sbt0.13.8も使用しています。
コード
name := """libsvmtest"""
version := "1.0"
scalaVersion := "2.11.7"
// Change this to another test framework if you prefer
libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.4" % "test"
// https://mvnrepository.com/artifact/com.datumbox/libsvm
libraryDependencies += "com.datumbox" % "libsvm" % "3.22"
import libsvm._
import scala.io.Source
import scala.collection.mutable
import scala.util.matching.Regex
import java.io.PrintWriter
object main{
def main(args: Array[String]){
val source = Source.fromFile(args(0), "MS932")
val uncategorized_source = Source.fromFile(args(1), "MS932")
var param = new svm_parameter()
param.svm_type=svm_parameter.C_SVC
param.kernel_type=svm_parameter.RBF
param.gamma=0.5
param.degree=3
param.coef0=0
param.nu=0.5
param.cache_size=20000
param.C=1
param.eps=0.001
param.p=0.1
var prob = new svm_problem()
prob.y = Array.empty//libsvmの関数、Array[Double]でlabelを入力する。
prob.x = Array.empty//libsvmの関数,Array(Array[svm_node])形式でデータ部分を入力する。
var inputarray : Array[svm_node] = Array.empty//prob.xに入力する用その2
//正規表現
var r1 = """(\d+)\s([\d+]:.+)""".r//(1 1:300 2:500~)のようなデータの場合、(1)と(1:300 2:500~) に分ける
var r2 = """(\d+):([\-\d\.\w]+)\s(\d+:.+)""".r//(1:300 2:500 3:700)のようなデータの場合は(1)と(300)と(2:500 3:700)に分ける。
var r3 = """(\d+):([\-\d\.\w]+)\s""".r//(3:700)のようなデータの場合に(3)と(700)に分ける。行末の処理用
def set_data_in_array(str:String):String=
str match{
case r1(t1,t2) =>
prob.y = prob.y :+ t1.toDouble
return set_data_in_array(t2)
case r2(t1,t2,t3) =>
var node = new svm_node//prob.xに入力する用
node.index = t1.toInt
node.value = BigDecimal(t2).bigDecimal.toPlainString.toDouble
inputarray = inputarray :+ node
return set_data_in_array(t3)
case r3(t1,t2) =>
var node = new svm_node//prob.xに入力する用
node.index = t1.toInt
node.value = BigDecimal(t2).bigDecimal.toPlainString.toDouble
inputarray = inputarray :+ node
node.index = -1//prob.xには配列の末のsvm_node.indexに-1を入力しなければならない。
inputarray = inputarray :+ node
prob.l=prob.l + 1//データ数を入力
prob.x = prob.x :+ inputarray
inputarray = Array.empty
return ""
}
var uncategorized_data: Array[svm_node] = Array.empty
def uncategorized_data_in_array(str:String):String=
str match{
case r2(t1,t2,t3) =>
var node = new svm_node//prob.xに入力する用
node.index = t1.toInt
node.value = BigDecimal(t2).bigDecimal.toPlainString.toDouble
uncategorized_data = uncategorized_data :+ node
return uncategorized_data_in_array(t3)
case r3(t1,t2) =>
var node = new svm_node//prob.xに入力する用
node.index = t1.toInt
node.value = BigDecimal(t2).bigDecimal.toPlainString.toDouble
uncategorized_data = uncategorized_data :+ node
node.index = -1
uncategorized_data = uncategorized_data :+ node
return ""
}
source.getLines().foreach{ line: String => set_data_in_array(line)}
var model = svm.svm_train(prob,param)
println(model)
uncategorized_source.getLines().foreach{ line: String => uncategorized_data_in_array(line)}
println("categorized:")
println(prob.l)
println(svm.svm_predict(model, uncategorized_data))
source.close()
}
}
1 1:-0.555556 2:0.25 3:-0.864407 4:-0.916667
1 1:-0.666667 2:-0.166667 3:-0.864407 4:-0.916667
1 1:-0.777778 3:-0.898305 4:-0.916667
1 1:-0.833333 2:-0.0833334 3:-0.830508 4:-0.916667
1 1:-0.611111 2:0.333333 3:-0.864407 4:-0.916667
1 1:-0.388889 2:0.583333 3:-0.762712 4:-0.75
1 1:-0.833333 2:0.166667 3:-0.864407 4:-0.833333
1 1:-0.611111 2:0.166667 3:-0.830508 4:-0.916667
1 1:-0.944444 2:-0.25 3:-0.864407 4:-0.916667
1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1
1 1:-0.388889 2:0.416667 3:-0.830508 4:-0.916667
1 1:-0.722222 2:0.166667 3:-0.79661 4:-0.916667
1 1:-0.722222 2:-0.166667 3:-0.864407 4:-1
1 1:-1 2:-0.166667 3:-0.966102 4:-1
1 1:-0.166667 2:0.666667 3:-0.932203 4:-0.916667
1 1:-0.222222 2:1 3:-0.830508 4:-0.75
1 1:-0.388889 2:0.583333 3:-0.898305 4:-0.75
1 1:-0.555556 2:0.25 3:-0.864407 4:-0.833333
1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333
1 1:-0.555556 2:0.5 3:-0.830508 4:-0.833333
1 1:-0.388889 2:0.166667 3:-0.762712 4:-0.916667
1 1:-0.555556 2:0.416667 3:-0.830508 4:-0.75
1 1:-0.833333 2:0.333333 3:-1 4:-0.916667
1 1:-0.555556 2:0.0833333 3:-0.762712 4:-0.666667
1 1:-0.722222 2:0.166667 3:-0.694915 4:-0.916667
1 1:-0.611111 2:-0.166667 3:-0.79661 4:-0.916667
1 1:-0.611111 2:0.166667 3:-0.79661 4:-0.75
1 1:-0.5 2:0.25 3:-0.830508 4:-0.916667
1 1:-0.5 2:0.166667 3:-0.864407 4:-0.916667
1 1:-0.777778 3:-0.79661 4:-0.916667
1 1:-0.722222 2:-0.0833334 3:-0.79661 4:-0.916667
1 1:-0.388889 2:0.166667 3:-0.830508 4:-0.75
1 1:-0.5 2:0.75 3:-0.830508 4:-1
1 1:-0.333333 2:0.833333 3:-0.864407 4:-0.916667
1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1
1 1:-0.611111 3:-0.932203 4:-0.916667
1 1:-0.333333 2:0.25 3:-0.898305 4:-0.916667
1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1
1 1:-0.944444 2:-0.166667 3:-0.898305 4:-0.916667
1 1:-0.555556 2:0.166667 3:-0.830508 4:-0.916667
1 1:-0.611111 2:0.25 3:-0.898305 4:-0.833333
1 1:-0.888889 2:-0.75 3:-0.898305 4:-0.833333
1 1:-0.944444 3:-0.898305 4:-0.916667
1 1:-0.611111 2:0.25 3:-0.79661 4:-0.583333
1 1:-0.555556 2:0.5 3:-0.694915 4:-0.75
1 1:-0.722222 2:-0.166667 3:-0.864407 4:-0.833333
1 1:-0.555556 2:0.5 3:-0.79661 4:-0.916667
1 1:-0.833333 3:-0.864407 4:-0.916667
1 1:-0.444444 2:0.416667 3:-0.830508 4:-0.916667
1 1:-0.611111 2:0.0833333 3:-0.864407 4:-0.916667
2 1:0.5 3:0.254237 4:0.0833333
2 1:0.166667 3:0.186441 4:0.166667
2 1:0.444444 2:-0.0833334 3:0.322034 4:0.166667
2 1:-0.333333 2:-0.75 3:0.0169491 4:-4.03573e-08
2 1:0.222222 2:-0.333333 3:0.220339 4:0.166667
2 1:-0.222222 2:-0.333333 3:0.186441 4:-4.03573e-08
2 1:0.111111 2:0.0833333 3:0.254237 4:0.25
2 1:-0.666667 2:-0.666667 3:-0.220339 4:-0.25
2 1:0.277778 2:-0.25 3:0.220339 4:-4.03573e-08
2 1:-0.5 2:-0.416667 3:-0.0169491 4:0.0833333
2 1:-0.611111 2:-1 3:-0.152542 4:-0.25
2 1:-0.111111 2:-0.166667 3:0.0847457 4:0.166667
2 1:-0.0555556 2:-0.833333 3:0.0169491 4:-0.25
2 1:-1.32455e-07 2:-0.25 3:0.254237 4:0.0833333
2 1:-0.277778 2:-0.25 3:-0.118644 4:-4.03573e-08
2 1:0.333333 2:-0.0833334 3:0.152542 4:0.0833333
2 1:-0.277778 2:-0.166667 3:0.186441 4:0.166667
2 1:-0.166667 2:-0.416667 3:0.0508474 4:-0.25
2 1:0.0555554 2:-0.833333 3:0.186441 4:0.166667
2 1:-0.277778 2:-0.583333 3:-0.0169491 4:-0.166667
2 1:-0.111111 3:0.288136 4:0.416667
2 1:-1.32455e-07 2:-0.333333 3:0.0169491 4:-4.03573e-08
2 1:0.111111 2:-0.583333 3:0.322034 4:0.166667
2 1:-1.32455e-07 2:-0.333333 3:0.254237 4:-0.0833333
2 1:0.166667 2:-0.25 3:0.118644 4:-4.03573e-08
2 1:0.277778 2:-0.166667 3:0.152542 4:0.0833333
2 1:0.388889 2:-0.333333 3:0.288136 4:0.0833333
2 1:0.333333 2:-0.166667 3:0.355932 4:0.333333
2 1:-0.0555556 2:-0.25 3:0.186441 4:0.166667
2 1:-0.222222 2:-0.5 3:-0.152542 4:-0.25
2 1:-0.333333 2:-0.666667 3:-0.0508475 4:-0.166667
2 1:-0.333333 2:-0.666667 3:-0.0847458 4:-0.25
2 1:-0.166667 2:-0.416667 3:-0.0169491 4:-0.0833333
2 1:-0.0555556 2:-0.416667 3:0.38983 4:0.25
2 1:-0.388889 2:-0.166667 3:0.186441 4:0.166667
2 1:-0.0555556 2:0.166667 3:0.186441 4:0.25
2 1:0.333333 2:-0.0833334 3:0.254237 4:0.166667
2 1:0.111111 2:-0.75 3:0.152542 4:-4.03573e-08
2 1:-0.277778 2:-0.166667 3:0.0508474 4:-4.03573e-08
2 1:-0.333333 2:-0.583333 3:0.0169491 4:-4.03573e-08
2 1:-0.333333 2:-0.5 3:0.152542 4:-0.0833333
2 1:-1.32455e-07 2:-0.166667 3:0.220339 4:0.0833333
2 1:-0.166667 2:-0.5 3:0.0169491 4:-0.0833333
2 1:-0.611111 2:-0.75 3:-0.220339 4:-0.25
2 1:-0.277778 2:-0.416667 3:0.0847457 4:-4.03573e-08
2 1:-0.222222 2:-0.166667 3:0.0847457 4:-0.0833333
2 1:-0.222222 2:-0.25 3:0.0847457 4:-4.03573e-08
2 1:0.0555554 2:-0.25 3:0.118644 4:-4.03573e-08
2 1:-0.555556 2:-0.583333 3:-0.322034 4:-0.166667
2 1:-0.222222 2:-0.333333 3:0.0508474 4:-4.03573e-08
3 1:0.111111 2:0.0833333 3:0.694915 4:1
3 1:-0.166667 2:-0.416667 3:0.38983 4:0.5
3 1:0.555555 2:-0.166667 3:0.661017 4:0.666667
3 1:0.111111 2:-0.25 3:0.559322 4:0.416667
3 1:0.222222 2:-0.166667 3:0.627119 4:0.75
3 1:0.833333 2:-0.166667 3:0.898305 4:0.666667
3 1:-0.666667 2:-0.583333 3:0.186441 4:0.333333
3 1:0.666667 2:-0.25 3:0.79661 4:0.416667
3 1:0.333333 2:-0.583333 3:0.627119 4:0.416667
3 1:0.611111 2:0.333333 3:0.728813 4:1
3 1:0.222222 3:0.38983 4:0.583333
3 1:0.166667 2:-0.416667 3:0.457627 4:0.5
3 1:0.388889 2:-0.166667 3:0.525424 4:0.666667
3 1:-0.222222 2:-0.583333 3:0.355932 4:0.583333
3 1:-0.166667 2:-0.333333 3:0.38983 4:0.916667
3 1:0.166667 3:0.457627 4:0.833333
3 1:0.222222 2:-0.166667 3:0.525424 4:0.416667
3 1:0.888889 2:0.5 3:0.932203 4:0.75
3 1:0.888889 2:-0.5 3:1 4:0.833333
3 1:-0.0555556 2:-0.833333 3:0.355932 4:0.166667
3 1:0.444444 3:0.59322 4:0.833333
3 1:-0.277778 2:-0.333333 3:0.322034 4:0.583333
3 1:0.888889 2:-0.333333 3:0.932203 4:0.583333
3 1:0.111111 2:-0.416667 3:0.322034 4:0.416667
3 1:0.333333 2:0.0833333 3:0.59322 4:0.666667
3 1:0.611111 3:0.694915 4:0.416667
3 1:0.0555554 2:-0.333333 3:0.288136 4:0.416667
3 1:-1.32455e-07 2:-0.166667 3:0.322034 4:0.416667
3 1:0.166667 2:-0.333333 3:0.559322 4:0.666667
3 1:0.611111 2:-0.166667 3:0.627119 4:0.25
3 1:0.722222 2:-0.333333 3:0.728813 4:0.5
3 1:1 2:0.5 3:0.830508 4:0.583333
3 1:0.166667 2:-0.333333 3:0.559322 4:0.75
3 1:0.111111 2:-0.333333 3:0.38983 4:0.166667
3 1:-1.32455e-07 2:-0.5 3:0.559322 4:0.0833333
3 1:0.888889 2:-0.166667 3:0.728813 4:0.833333
3 1:0.111111 2:0.166667 3:0.559322 4:0.916667
3 1:0.166667 2:-0.0833334 3:0.525424 4:0.416667
3 1:-0.0555556 2:-0.166667 3:0.288136 4:0.416667
3 1:0.444444 2:-0.0833334 3:0.491525 4:0.666667
3 1:0.333333 2:-0.0833334 3:0.559322 4:0.916667
3 1:0.444444 2:-0.0833334 3:0.38983 4:0.833333
3 1:-0.166667 2:-0.416667 3:0.38983 4:0.5
3 1:0.388889 3:0.661017 4:0.833333
3 1:0.333333 2:0.0833333 3:0.59322 4:1
3 1:0.333333 2:-0.166667 3:0.423729 4:0.833333
3 1:0.111111 2:-0.583333 3:0.355932 4:0.5
3 1:0.222222 2:-0.166667 3:0.423729 4:0.583333
3 1:0.0555554 2:0.166667 3:0.491525 4:0.833333
3 1:-0.111111 2:-0.166667 3:0.38983 4:0.416667
1:0.111111 2:-0.166667 3:0.423729 4:0.583333
階層
root/
├build.sbt
├data.scale
├iris.scale
└src
└main
└scala
└libsvmtest.scala
実行方法
sbt "run iris.scale data.scale"
libsvmのクラスについて
###svm_problem
このクラスに学習したいデータを入力していく。
svm_problemには、
svm_problem.x,svm_problem.y,svm_problem.lの3つのメソッドがある。
svm_problem.lはデータの数
svm_problem.yは全データのラベルで今回のirisのデータだと1,2,3のどれかの配列
(よく見るlibsvm形式のデータだとラベルは整数なので整数しか駄目だと思っていたのですが、どうやらArray(double)で入力できるので整数以外もラベルにできるんですね)
svm_problem.xはArray(Array[svm_node])で入力します。
svm_node
svm_nodeのクラスはsvm_node.index(Int)とsvm_node.value(double)の2つのメソッドを持っていて、indexにデータの番号を入力し、valueに値を入力します。
今回のirisデータの場合、1:-0.555556のようなデータだと1がindexで-0.555556がvalueになってます。
ラベル以外の、一行のデータはArray([svm_node])に格納しています。
注意点として、一つのデータ配列の末尾のindexに-1を入力すること
1:-0.555556 2:0.25 3:-0.864407 4:-0.916667のようなデータの場合、
array((1,-0.555556),(2,0.25),(3,-0.864407),(4,-0.916667),(-1,-1))
svm_parameter
param.svm_type=svm_parameter.C_SVC
param.kernel_type=svm_parameter.RBF
param.gamma=0.5
param.degree=3
param.coef0=0
param.nu=0.5
param.cache_size=20000
param.C=1
param.eps=0.001
param.p=0.1
svm_parameterでは名前の通り、パラメーターを指定できる。
パラメーターのノウハウは私より詳しい記事がたくさんあるのでそちらで・・・。
svmのタイプ指定とカーネルのタイプは数値でもできますが、可読性を考慮するなら文字でしていたほうが良いともいます。
typeで指定できるのは、
C_SVC,NU_SVC,ONE_CLASS,EPSILON_SVR,NU_SVR
カーネルタイプの指定は、LINEAR,POLY,RBF,SIGMOID,PRECOMPUTED
svm.svm_train(データ,パラメーター)
1つ目の引数にsvm_problemを、2つ目の引数にsvm_parameterを入力することでモデルを出してくれます。
svm.svm_predict(モデル,判別したいデータ)
1つ目の引数に先程のsvm.svm_train作られたモデルを与え、2つ目の引数に判別したいデータをsvm_nodeに入力して渡すことで判別してくれます。