More than 3 years have passed since last update.

Scala Spark study notes ⑥: decision tree regression, for starters


Continuing with quick-and-dirty machine learning.
This time: a decision tree.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.regression.DecisionTreeRegressor
import org.apache.spark.ml.evaluation.RegressionEvaluator

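// trainDF and testDF are assumed to be already-loaded DataFrames with a Double "price" label column.
// Index every string column so the tree can treat it as a categorical feature.
val categoricalCols = trainDF.dtypes.filter(_._2 == "StringType").map(_._1)
val indexOutputCols = categoricalCols.map(_ + "index")

// The multi-column setInputCols/setOutputCols on StringIndexer needs Spark 3.0 or later.
val stringIndexer = new StringIndexer().setInputCols(categoricalCols).setOutputCols(indexOutputCols).setHandleInvalid("skip")
// Numeric features: every Double column except the label.
val numericCols = trainDF.dtypes.filter{ case (field, dataType) => dataType == "DoubleType" && field != "price"}.map(_._1)
// Assemble the indexed columns, not the raw strings: VectorAssembler rejects StringType input.
val assemblerInputs = indexOutputCols ++ numericCols
val vecAssembler = new VectorAssembler().setInputCols(assemblerInputs).setOutputCol("features")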

val dt = new DecisionTreeRegressor().setLabelCol("price")
// maxBins must be at least the largest category count of any categorical feature (the default is 32).
dt.setMaxBins(40)
val stages = Array(stringIndexer, vecAssembler, dt)
val pipeline = new Pipeline().setStages(stages)
val pipelineModel = pipeline.fit(trainDF)

// Predict on the held-out data and score the predictions with RMSE.
val predDF = pipelineModel.transform(testDF)
val regressionEvaluator = new RegressionEvaluator().setPredictionCol("prediction").setLabelCol("price").setMetricName("rmse")
val rmse = regressionEvaluator.evaluate(predDF)
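
As a reminder of what the "rmse" metric above measures: it is the square root of the mean squared difference between the prediction and the label. Below is a minimal plain-Scala sketch of that calculation, independent of Spark. The object name `RmseSketch` and the sample (prediction, price) pairs are made up for illustration and do not come from the data used in this post.

```scala
// Plain-Scala sketch of the RMSE calculation; no Spark required.
object RmseSketch {
  /** Root mean squared error over (prediction, label) pairs. */
  def rmse(pairs: Seq[(Double, Double)]): Double = {
    require(pairs.nonEmpty, "need at least one (prediction, label) pair")
    // Average the squared errors, then take the square root.
    val meanSquaredError =
      pairs.map { case (prediction, label) => math.pow(prediction - label, 2) }.sum / pairs.size
    math.sqrt(meanSquaredError)
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical (prediction, price) pairs, invented for this example.
    val sample = Seq((100.0, 110.0), (250.0, 240.0), (80.0, 80.0))
    println(s"RMSE = ${rmse(sample)}")
  }
}
```

Because the errors are squared before averaging, a few badly mispredicted prices raise the RMSE much more than many small misses, which is worth keeping in mind when reading the score.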