More than 3 years have passed since last update.

Scala Spark Study Notes ⑤: Machine Learning, for Starters

Posted at 2021-09-29

I will come back and fill in the details later.
For now, I am just moving the code here.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.evaluation.RegressionEvaluator

val filePath = "/home/ubuntu/LearningSparkV2/databricks-datasets/learning-spark-v2/sf-airbnb/sf-airbnb-clean.parquet/"
val airbnbDF = spark.read.parquet(filePath)
val Array(trainDF, testDF) = airbnbDF.randomSplit(Array(.8, .2), seed=42)

val vecAssembler = new VectorAssembler().setInputCols(Array("bedrooms", "bathrooms")).setOutputCol("features")
val lr = new LinearRegression().setFeaturesCol("features").setLabelCol("price")

val pipeline = new Pipeline().setStages(Array(vecAssembler, lr))
val pipelineModel = pipeline.fit(trainDF)
val predDF = pipelineModel.transform(testDF)


val categoricalCols = trainDF.dtypes.filter(_._2 == "StringType").map(_._1)
val indexOutputCols = categoricalCols.map(_ + "index")
val oheOutputCols = categoricalCols.map(_ + "OHE")

// Index the string columns first; OneHotEncoder only accepts numeric input columns
val stringIndexer = new StringIndexer().setInputCols(categoricalCols).setOutputCols(indexOutputCols).setHandleInvalid("skip")
val oheEncoder = new OneHotEncoder().setInputCols(indexOutputCols).setOutputCols(oheOutputCols)
val numericCols = trainDF.dtypes.filter{ case (field, dataType) => dataType == "DoubleType" && field != "price"}.map(_._1)
// Assemble the one-hot encoded columns, not the raw string columns
val assemblerInputs = oheOutputCols ++ numericCols
val vecAssembler = new VectorAssembler().setInputCols(assemblerInputs).setOutputCol("features")

val lr = new LinearRegression().setLabelCol("price").setFeaturesCol("features")
val pipeline = new Pipeline().setStages(Array(stringIndexer, oheEncoder, vecAssembler, lr))
val pipelineModel = pipeline.fit(trainDF)
My first attempt failed at this fit with the error below. The cause appears to be that I had passed the raw string columns (categoricalCols) to both OneHotEncoder and VectorAssembler; the code above passes indexOutputCols to the encoder and oheOutputCols to the assembler instead.
ERROR Instrumentation: java.lang.IllegalArgumentException: requirement failed: Column host_is_superhost must be of type numeric but was actually of type string.
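Apparently the following achieves the same thing. (RFormula one-hot encodes string columns automatically, so it seems suitable only when that is acceptable; if you are going to one-hot encode anyway, this way is reportedly quicker.)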

val rFormula = new RFormula().setFormula("price ~ .").setFeaturesCol("features").setLabelCol("price").setHandleInvalid("skip")
val pipeline = new Pipeline().setStages(Array(rFormula, lr))
val pipelineModel = pipeline.fit(trainDF)
val predDF = pipelineModel.transform(testDF)
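
To see what RFormula does with a string column, here is a minimal sketch on a tiny made-up DataFrame. The names `session`, `toyDF`, `toyFormula`, and `toyOut` and the three rows are hypothetical and are not part of the Airbnb dataset; only the column names mirror the ones above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.RFormula

// Reuses the running session in spark-shell, or starts a local one (assumption: local mode)
val session = SparkSession.builder().master("local[*]").appName("rformula-sketch").getOrCreate()

// Hypothetical toy rows: one numeric label, one string column, one numeric feature
val toyDF = session.createDataFrame(Seq(
  (100.0, "t", 1.0),
  (150.0, "f", 2.0),
  (120.0, "t", 1.0)
)).toDF("price", "host_is_superhost", "bedrooms")

// Same formula as above: "." means every column except the label
val toyFormula = new RFormula().setFormula("price ~ .").setFeaturesCol("features").setLabelCol("price")
val toyOut = toyFormula.fit(toyDF).transform(toyDF)

// The string column is indexed and one-hot encoded inside "features"
toyOut.select("host_is_superhost", "bedrooms", "features").show(false)
```

No StringIndexer or OneHotEncoder stage is declared here, which is the convenience being traded against control over how each column is encoded.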

val regressionEvaluator = new RegressionEvaluator().setPredictionCol("prediction").setLabelCol("price").setMetricName("rmse")
val rmse = regressionEvaluator.evaluate(predDF)
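
As a rough check of how RegressionEvaluator behaves, the sketch below evaluates both RMSE and R2 on a small hand-made set of predictions. The DataFrame `toyPredDF`, its values, and the names `evalSession`, `toyEvaluator`, `toyRmse`, and `toyR2` are assumptions for illustration; with the real pipeline, `predDF` would take the place of `toyPredDF`.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.evaluation.RegressionEvaluator

// Reuses the running session in spark-shell, or starts a local one (assumption: local mode)
val evalSession = SparkSession.builder().master("local[*]").appName("evaluator-sketch").getOrCreate()

// Hypothetical (label, prediction) pairs with errors of -10, +10 and 0
val toyPredDF = evalSession.createDataFrame(Seq(
  (100.0, 110.0),
  (200.0, 190.0),
  (300.0, 300.0)
)).toDF("price", "prediction")

val toyEvaluator = new RegressionEvaluator().setPredictionCol("prediction").setLabelCol("price")

// The same evaluator can be switched between metrics with setMetricName
val toyRmse = toyEvaluator.setMetricName("rmse").evaluate(toyPredDF)
val toyR2 = toyEvaluator.setMetricName("r2").evaluate(toyPredDF)

println(f"RMSE = $toyRmse%.3f, R2 = $toyR2%.3f")
```

RMSE is in the units of the label (here, price), while R2 is unitless, so reporting both gives a slightly fuller picture than RMSE alone.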