Apache Spark 1.2.0で、MLlibにRandomForestが実装されていまして、分類(classification)も回帰(regression)も出来るようになっています。
SBTを使うと、比較的簡単にローカル環境で試せたので、メモしておきます。
環境
必要なのは、JDKとSBTの2つ。
OSは「Windows 8.1 64bit版」の無印(not Pro)でやってますが、特別なことはしていないので、他のOSでも動くんじゃないかなと。
JDK8
> javac -version
javac 1.8.0_25
SBT
> sbt about
[info] This is sbt 0.13.7
SBTは、msiをダウンロードして普通にインストーラーから入れたのですが (インストール先はc:\scala\sbtに変更)、環境変数PATHまで通してくれて便利でした。
Steps
プロジェクトフォルダを作る
cdしやすい場所にプロジェクトのフォルダを作ります。私の場合はこんな感じで。
> mkdir c:\scala\hello-rf
build.sbtを作る
「build.sbt」というファイル名のテキストファイルを作って、そこに以下の内容を記述します1。
name := "Hello RandomForest"
version := "0.0.1"
scalaVersion := "2.10.4"
scalacOptions ++= Seq("-Xlint", "-deprecation", "-unchecked", "-feature", "-Xelide-below", "ALL")
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % "1.2.1",
"org.apache.spark" %% "spark-mllib" % "1.2.1"
)
sparkのバージョンは1.2.1がMavenのrepoに出てるので、それを利用しています。
Maven Repository: org.apache.spark
RandomForestを呼び出すコードをScalaで作る
サンプルコードがMLlibのページにあるのですが、
irisデータのほうがRとの比較とかやりやすいでしょ(あとこのサンプル、私は出力が分かりにくかった)、ってことで、以下のように変えてみました。
これをHelloRf.scalaという名前でプロジェクトフォルダ直下に置きます。
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils
object HelloRf
{
val sc = new SparkContext("local", "HelloRf")
// Load and parse the data file.
// libsvm style iris Data - http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale
val data = MLUtils.loadLibSVMFile(sc, "iris.scale")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
def main(args: Array[String]): Unit =
{
trainClassifier()
trainRegressor()
}
def trainClassifier() =
{
val startTime = System.currentTimeMillis
// Train a RandomForest model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 4 // Iris data: 3 labels, (label + 1) value seems to be needed.
val categoricalFeaturesInfo = Map[Int, Int]()
val numTrees = 3 // Use more in practice.
val featureSubsetStrategy = "auto" // Let the algorithm choose.
val impurity = "gini"
val maxDepth = 4 // <= 30
val maxBins = 32
val model = RandomForest.trainClassifier(
trainingData, numClasses, categoricalFeaturesInfo,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
// Evaluate model on test instances and compute test error
val labelAndPreds = testData.zipWithIndex.map
{
case(current, index) =>
val predictionResult = model.predict(current.features)
(index, current.label, predictionResult, current.label == predictionResult) // Tuple
}
val execTime = System.currentTimeMillis - startTime
val testDataCount = testData.count()
val testErrCount = labelAndPreds.filter(r => !r._4).count // r._4 = 4th element of tuple (current.label == predictionResult)
val testSuccessRate = 100 - (testErrCount.toDouble / testDataCount * 100)
println("RfClassifier Results: " + testSuccessRate + "% numTrees: " + numTrees + " maxDepth: " + maxDepth + " execTime(msec): " + execTime)
println("Test Data Count = " + testDataCount)
println("Test Error Count = " + testErrCount)
println("Test Success Rate (%) = " + testSuccessRate)
println("Learned classification forest model:\n" + model.toDebugString)
labelAndPreds.foreach(x => println(x))
}
def trainRegressor()
{
val startTime = System.currentTimeMillis
// Train a RandomForest model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 0 // not used for regression
val categoricalFeaturesInfo = Map[Int, Int]()
val numTrees = 3 // Use more in practice.
val featureSubsetStrategy = "auto" // Let the algorithm choose.
val impurity = "variance"
val maxDepth = 4
val maxBins = 32
val model = RandomForest.trainRegressor(
trainingData, categoricalFeaturesInfo,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
// Evaluate model on test instances and compute test error
val labelsAndPredictions = testData.zipWithIndex.map
{
case(current, index) =>
val predictionResult = model.predict(current.features)
(index, current.label, predictionResult) // Tuple
}
val execTime = System.currentTimeMillis - startTime
println("RfRegressor Results: execTime(msec): " + execTime)
println("Learned regression forest model:\n" + model.toDebugString)
labelsAndPredictions.foreach(x => println(x))
}
}
まあ、ファイル名は拡張子.scalaであれば何でもいいです。ファイルを置く場所も、綺麗好きな方は、プロジェクトフォルダ直下ではなくてmavenスタイルな「src\main\scala」フォルダを作って、そこに入れたらいいと思います。
Irisデータを入手する
データはRDD<LabeledPoint>
として読み込んでやればいいということなんですが、libsvm形式のirisデータがあったので、これをサンプル通りMLUtils.loadLibSVMFile
で読み込むという方針で。
以下からダウンロードして、プロジェクトフォルダ直下に置きます。
実行する
コマンドプロンプトで、
> cd [your project folder]
> sbt compile
> sbt run > sysout.txt
リダイレクトでtxtを吐いておくと、printlnのところだけ取れるんでわかりやすいかなと。
実際にやってみた結果がこちら。irisデータに対してClassifierの正解率94.0%は、まあ普通といったところでしょうか。Treeの数、if-elseの数と深さは、numTrees, numDepth, maxBinsを弄ると変わるんで、お試しあれってところですね。
RfClassifier Results: 94.0% numTrees: 3 maxDepth: 4 execTime(msec): 11237
Test Data Count = 50
Test Error Count = 3
Test Success Rate (%) = 94.0
Learned classification forest model:
TreeEnsembleModel classifier with 3 trees
Tree 0:
If (feature 2 <= -0.694915)
Predict: 1.0
Else (feature 2 > -0.694915)
If (feature 3 <= 0.25)
...
Tree 1:
If (feature 0 <= -0.388889)
If (feature 2 <= -0.762712)
Predict: 1.0
Else (feature 2 > -0.762712)
If (feature 2 <= -0.152542)
Predict: 2.0
...
Tree 2:
If (feature 2 <= -0.694915)
Predict: 1.0
Else (feature 2 > -0.694915)
If (feature 3 <= 0.166667)
...
(0,1.0,1.0,true)
(1,1.0,1.0,true)
(2,1.0,1.0,true)
...
(8,1.0,1.0,true)
...
(36,3.0,2.0,false)
...
(43,3.0,2.0,false)
(44,3.0,2.0,false)
...
(48,3.0,3.0,true)
(49,3.0,3.0,true)
RfRegressor Results: execTime(msec): 5796
Learned regression forest model:
TreeEnsembleModel regressor with 3 trees
Tree 0:
If (feature 3 <= -0.583333)
Predict: 1.0
Else (feature 3 > -0.583333)
If (feature 3 <= 0.25)
If (feature 1 <= -0.5)
If (feature 3 <= -4.03573E-8)
Predict: 2.0
...
Tree 1:
If (feature 2 <= -0.762712)
Predict: 1.0
Else (feature 2 > -0.762712)
If (feature 2 <= 0.254237)
Predict: 2.0
...
Tree 2:
If (feature 2 <= -0.694915)
Predict: 1.0
Else (feature 2 > -0.694915)
If (feature 2 <= 0.288136)
If (feature 2 <= 0.152542)
Predict: 2.0
...
(0,1.0,1.0)
(1,1.0,1.0)
(2,1.0,1.0)
...
(8,1.0,1.3333333333333333)
...
(36,3.0,2.6666666666666665)
...
(43,3.0,2.3333333333333335)
(44,3.0,2.6944444444444446)
...
(48,3.0,3.0)
(49,3.0,3.0)
Next Step
numTrees, numDepth, maxBinsの調整をして、精度と実行時間の兼ね合いを見るなど。っても150件しかないirisデータなんで、適当でいいかな感はありますが。
-
build.sbtについては、 「Spark MLlibで相関係数を算出してみる。」を参考にさせて頂きました ↩