Login / Signup
1
1

More than 5 years have passed since last update.

spark naive bayes 実験メモ

Last updated at Posted at 2015-10-01
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.feature.HashingTF

// Term-frequency featurizer that hashes tokens into 10000 buckets.
val htf = new HashingTF(10000)

// Each input line is one document; it is split on runs of whitespace and hashed
// into a term-frequency vector. Lines of test_pos.txt are labelled 0.0 and lines
// of test_neg.txt are labelled 1.0.
val pos_data = sc.textFile("test_pos.txt").map { text => new LabeledPoint(0.0, htf.transform(text.split("\\s+"))) }
val neg_data = sc.textFile("test_neg.txt").map { text => new LabeledPoint(1.0, htf.transform(text.split("\\s+"))) }

val data = pos_data.union(neg_data)

// 60% training / 40% test; the fixed seed keeps the split reproducible.
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0)
// The test split is consumed by more than one action further down; caching it
// avoids re-reading and re-hashing both input files for every action.
val test = splits(1).cache()

// Multinomial naive Bayes with additive smoothing lambda = 1.0.
val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial")

// Confusion-matrix tally over the test split, keyed by "TP" / "TN" / "FP" / "FN".
// Label 1.0 is treated as the "positive" class here; in this script that is the
// label given to the rows read from test_neg.txt.
// Only the four (0.0 | 1.0) combinations are handled: any other prediction or
// label value fails the job with a MatchError.
val result = test.map { t =>
  val predicted = model.predict(t.features)

  (predicted, t.label) match {
    case (0.0, 0.0) => "TN"
    case (0.0, 1.0) => "FN"
    case (1.0, 0.0) => "FP"
    case (1.0, 1.0) => "TP"
  }
}.countByValue()

// Every test record was tallied under exactly one of the four outcome keys, so
// the sum of the tallies is the size of the test split; no extra Spark job is
// needed to count it.
val totalCount = result.values.sum

// An outcome that never occurred is simply absent from the map; treat it as zero.
val truePositiveCount = result.getOrElse("TP", 0L).toDouble
val trueNegativeCount = result.getOrElse("TN", 0L).toDouble
val falsePositiveCount = result.getOrElse("FP", 0L).toDouble
val falseNegativeCount = result.getOrElse("FN", 0L).toDouble

// All ratios use Double arithmetic, so a zero denominator (e.g. no positive
// predictions at all) yields NaN rather than throwing.
val accuracy = (truePositiveCount + trueNegativeCount) / totalCount
// Threat score (critical success index): TP / (TP + FP + FN).
val threatscore = truePositiveCount / (truePositiveCount + falsePositiveCount + falseNegativeCount)
val precision = truePositiveCount / (truePositiveCount + falsePositiveCount)
val recall = truePositiveCount / (truePositiveCount + falseNegativeCount)
// F1 score written as TP / (TP + (FP + FN) / 2), which equals 2TP / (2TP + FP + FN).
val f = truePositiveCount / (truePositiveCount + (falsePositiveCount + falseNegativeCount) / 2)

println(s"accuracy: $accuracy")
println(s"threatscore: $threatscore")
println(s"precision: $precision")
println(s"recall: $recall")
println(s"f: $f")

// Dump "<actual label>: <predicted label>" for every test sample.
// The whole test split is collected to the driver and predicted there.
for (sample <- test.collect) {
  val guess = model.predict(sample.features)
  println(s"${sample.label}: $guess")
}
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1