21
22

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Apache SparkのMLlibを使って、RandomForestをローカル環境で試す

Last updated at Posted at 2015-02-25

Apache Spark 1.2.0で、MLlibにRandomForestが実装されていまして、分類(classification)も回帰(regression)も出来るようになっています。

SBTを使うと、比較的簡単にローカル環境で試せたので、メモしておきます。

環境

必要なのは、JDKとSBTの2つ。

OSは「Windows 8.1 64bit版」の無印(not Pro)でやってますが、特別なことはしていないので、他のOSでも動くんじゃないかなと。

JDK8

> javac -version
javac 1.8.0_25

SBT

> sbt about
[info] This is sbt 0.13.7

SBTは、msiをダウンロードして普通にインストーラーから入れたのですが (インストール先はc:\scala\sbtに変更)、環境変数PATHまで通してくれて便利でした。

SBT Download

Steps

プロジェクトフォルダを作る

cdしやすい場所にプロジェクトのフォルダを作ります。私の場合はこんな感じで。

> mkdir c:\scala\hello-rf

build.sbtを作る

「build.sbt」というファイル名のテキストファイルを作って、そこに以下の内容を記述します1

build.sbt
name := "Hello RandomForest"

version := "0.0.1"

scalaVersion := "2.10.4"

scalacOptions ++= Seq("-Xlint", "-deprecation", "-unchecked", "-feature", "-Xelide-below", "ALL")

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "1.2.1",
  "org.apache.spark" %% "spark-mllib" % "1.2.1"
)

sparkのバージョンは1.2.1がMavenのrepoに出てるので、それを利用しています。

Maven Repository: org.apache.spark

RandomForestを呼び出すコードをScalaで作る

サンプルコードがMLlibのページにあるのですが、

Emsambles - MLlib

irisデータのほうがRとの比較とかやりやすいでしょ(あとこのサンプル、私は出力が分かりにくかった)、ってことで、以下のように変えてみました。

これをHelloRf.scalaという名前でプロジェクトフォルダ直下に置きます。

HelloRf.scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils

object HelloRf
{
  val sc = new SparkContext("local", "HelloRf")

  // Load and parse the data file.
  // libsvm style iris Data - http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale
  val data = MLUtils.loadLibSVMFile(sc, "iris.scale")

  // Split the data into training and test sets (30% held out for testing)
  val splits = data.randomSplit(Array(0.7, 0.3))
  val (trainingData, testData) = (splits(0), splits(1))

  def main(args: Array[String]): Unit =
  {
    trainClassifier()
    trainRegressor()
  }

  def trainClassifier() =
  {
    val startTime = System.currentTimeMillis
    
    // Train a RandomForest model.
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    val numClasses = 4 // Iris data: 3 labels, (label + 1) value seems to be needed.
    val categoricalFeaturesInfo = Map[Int, Int]()
    val numTrees = 3 // Use more in practice.
    val featureSubsetStrategy = "auto" // Let the algorithm choose.
    val impurity = "gini"
    val maxDepth = 4 // <= 30
    val maxBins = 32

    val model = RandomForest.trainClassifier(
      trainingData, numClasses, categoricalFeaturesInfo,
      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

    // Evaluate model on test instances and compute test error
    val labelAndPreds = testData.zipWithIndex.map
    {
      case(current, index) =>
        val predictionResult = model.predict(current.features)
        (index, current.label, predictionResult, current.label == predictionResult) // Tuple
    }
    
    val execTime = System.currentTimeMillis - startTime
    
    val testDataCount = testData.count()
    val testErrCount = labelAndPreds.filter(r => !r._4).count // r._4 = 4th element of tuple (current.label == predictionResult)
    val testSuccessRate = 100 - (testErrCount.toDouble / testDataCount * 100)

    println("RfClassifier Results: " + testSuccessRate + "% numTrees: " + numTrees + " maxDepth: " + maxDepth + " execTime(msec): " + execTime)
    println("Test Data Count = " + testDataCount)
    println("Test Error Count = " + testErrCount)
    println("Test Success Rate (%) = " + testSuccessRate)
    println("Learned classification forest model:\n" + model.toDebugString)
    labelAndPreds.foreach(x => println(x))
  }

  def trainRegressor()
  {
    val startTime = System.currentTimeMillis
    
    // Train a RandomForest model.
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    val numClasses = 0 // not used for regression
    val categoricalFeaturesInfo = Map[Int, Int]()
    val numTrees = 3 // Use more in practice.
    val featureSubsetStrategy = "auto" // Let the algorithm choose.
    val impurity = "variance"
    val maxDepth = 4
    val maxBins = 32

    val model = RandomForest.trainRegressor(
      trainingData, categoricalFeaturesInfo,
      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

    // Evaluate model on test instances and compute test error
    val labelsAndPredictions = testData.zipWithIndex.map
    {
      case(current, index) =>
        val predictionResult = model.predict(current.features)
        (index, current.label, predictionResult) // Tuple
    }
    
    val execTime = System.currentTimeMillis - startTime
    
    println("RfRegressor Results: execTime(msec): " + execTime)
    println("Learned regression forest model:\n" + model.toDebugString)
    labelsAndPredictions.foreach(x => println(x))
  }
}

まあ、ファイル名は拡張子.scalaであれば何でもいいです。ファイルを置く場所も、綺麗好きな方は、プロジェクトフォルダ直下ではなくてmavenスタイルな「src\main\scala」フォルダを作って、そこに入れたらいいと思います。

Irisデータを入手する

データはRDD<LabeledPoint>として読み込んでやればいいということなんですが、libsvm形式のirisデータがあったので、これをサンプル通りMLUtils.loadLibSVMFileで読み込むという方針で。

以下からダウンロードして、プロジェクトフォルダ直下に置きます。

LIBSVM Data - iris.scale

実行する

コマンドプロンプトで、

> cd [your project folder]
> sbt compile
> sbt run > sysout.txt

リダイレクトでtxtを吐いておくと、printlnのところだけ取れるんでわかりやすいかなと。

実際にやってみた結果がこちら。irisデータに対してClassifierの正解率94.0%は、まあ普通といったところでしょうか。Treeの数、if-elseの数と深さは、numTrees, numDepth, maxBinsを弄ると変わるんで、お試しあれってところですね。

RfClassifier Results: 94.0% numTrees: 3 maxDepth: 4 execTime(msec): 11237
Test Data Count = 50
Test Error Count = 3
Test Success Rate (%) = 94.0
Learned classification forest model:
TreeEnsembleModel classifier with 3 trees

  Tree 0:
    If (feature 2 <= -0.694915)
     Predict: 1.0
    Else (feature 2 > -0.694915)
     If (feature 3 <= 0.25)
     ...
  Tree 1:
    If (feature 0 <= -0.388889)
     If (feature 2 <= -0.762712)
      Predict: 1.0
     Else (feature 2 > -0.762712)
      If (feature 2 <= -0.152542)
       Predict: 2.0
      ...
  Tree 2:
    If (feature 2 <= -0.694915)
     Predict: 1.0
    Else (feature 2 > -0.694915)
     If (feature 3 <= 0.166667)
     ...

(0,1.0,1.0,true)
(1,1.0,1.0,true)
(2,1.0,1.0,true)
...
(8,1.0,1.0,true)
...
(36,3.0,2.0,false)
...
(43,3.0,2.0,false)
(44,3.0,2.0,false)
...
(48,3.0,3.0,true)
(49,3.0,3.0,true)

RfRegressor Results: execTime(msec): 5796
Learned regression forest model:
TreeEnsembleModel regressor with 3 trees

  Tree 0:
    If (feature 3 <= -0.583333)
     Predict: 1.0
    Else (feature 3 > -0.583333)
     If (feature 3 <= 0.25)
      If (feature 1 <= -0.5)
       If (feature 3 <= -4.03573E-8)
        Predict: 2.0
       ...
  Tree 1:
    If (feature 2 <= -0.762712)
     Predict: 1.0
    Else (feature 2 > -0.762712)
     If (feature 2 <= 0.254237)
      Predict: 2.0
     ...
  Tree 2:
    If (feature 2 <= -0.694915)
     Predict: 1.0
    Else (feature 2 > -0.694915)
     If (feature 2 <= 0.288136)
      If (feature 2 <= 0.152542)
       Predict: 2.0
      ...

(0,1.0,1.0)
(1,1.0,1.0)
(2,1.0,1.0)
...
(8,1.0,1.3333333333333333)
...
(36,3.0,2.6666666666666665)
...
(43,3.0,2.3333333333333335)
(44,3.0,2.6944444444444446)
...
(48,3.0,3.0)
(49,3.0,3.0)

Next Step

numTrees, numDepth, maxBinsの調整をして、精度と実行時間の兼ね合いを見るなど。っても150件しかないirisデータなんで、適当でいいかな感はありますが。

  1. build.sbtについては、 「Spark MLlibで相関係数を算出してみる。」を参考にさせて頂きました

21
22
6

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
22

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?