Sparkに依存せずにモデルを実行するMleapを使った方法の記事を記載、こっちの方がいいと思う。(2019/10/12追記)
はじめに
- 備忘録のようなクソエントリー
- Sparkの勉強がてらに作った機械学習モデルをWebAPIにするやーつ
Apache Spark MLlibで作ったモデルで予測サーバーを作りたい
-
Spark MLlib
で作った機械学習モデルをWebAPIにする実験。 - APサーバは
http4s
を利用。Sparkが動いているのもAPサーバが動いているのも同一マシン上。
build.sbt
build.sbt
name := "sparktest"
version := "0.1"
scalaVersion := "2.11.12" //Spark2.4.3だとscala2.11でないと動かなかった。
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.3"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.3"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "2.4.3"
libraryDependencies += "org.http4s" %% "http4s-dsl" % "0.20.10"
libraryDependencies += "org.http4s" %% "http4s-blaze-server" % "0.20.10"
libraryDependencies += "org.http4s" %% "http4s-circe" % "0.20.10"
libraryDependencies += "io.circe" %% "circe-generic" % "0.11.1"
libraryDependencies += "io.circe" %% "circe-literal" % "0.11.1"
モデル作って保存する
- とりあえず適当に機械学習モデルをSpark MLlibを用いて作成する。
- データはScikit-learnにある例のIrisのデータをCSVにしたもの
- できたモデルをWebアプリから利用するため保存しておく。
import org.apache.spark.sql.{SparkSession,Encoders}
import org.apache.spark.types.StructType
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StandardScaler,VectorAssembler}
import org.apache.spark.ml.classfication.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassficationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator,CrossValidatorModel,ParamGridBuilder}
implicit val spark = SparkSession.builder().master("local[*]").getOrCreate()
case class IrisData(f0:Double,f1:Double,f2:Double,f3:Double,target:Int)
def train(path:String)(implicit spark:SparkSession):Unit ={
val schema = new StructType() //CSVファイルのスキーマ
.add("id","Int")
.add("f0","Double")
.add("f1","Double")
.add("f2","Double")
.add("f3","Double")
.add("target","Int")
import spark.implicits._
//CSVファイル読み込んでDataSetにする
val ds = spark
.read
.schema(schema)
.option("header",true)
.csv("iris.csv")
.select("f0","f1","f2","f3","target")
.as[IrisData](Encoders.product[IrisData])
//訓練データと検証データに分割
val Array(train, test) = ds.randomSplit(Array(0.8, 0.2))
//特徴ベクトルを1つの"features"列にする
val assembler = new VectorAssembler()
.setInputCols(Array("f0", "f1", "f2", "f3"))
.setOutputCol("features")
//標準化
val scaler = new StandardScaler()
.setInputCol(assembler.getOutputCol)
.setOutputCol("scaledFeatures")
//ランダムフォレスト
val randomForest = new RandomForestClassifier()
.setFeaturesCol(scaler.getOutputCol)
.setLabelCol("target")
//パイプライン
val pipeline = new Pipeline()
.setStages(Array(assembler, scaler, randomForest))
//グリッドサーチ用のハイパーパラメータ
val paramGrid = new ParamGridBuilder()
.addGrid(randomForest.maxDepth, Array(8, 16))
.addGrid(randomForest.numTrees, Array(16, 32))
.build()
//評価器
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol(randomForest.getLabelCol)
.setPredictionCol(randomForest.getPredictionCol)
//交差検定
val validator = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(3)
//学習
val model = validator.fit(train)
//学習結果の保存
model.save(path)
//予測
val predict = model.transform(test)
//予測結果の正解数
val positiveCount = predict
.select("target","prediction") //正解値と予測値のみ抽出し
.map{data => if (data.getInt(0).toDouble == data.getDouble(1)) 1 else 0} //一致している場合だけ1
.reduce{_ + _} //その合計値
println("正解数:${positiveCount}\t正解率:${positiveCount / predict.count().toDouble}")
}
train("savedModel")
予測処理
-
SparkSession
をリクエスト間で共有しているけれど、本当に問題ないか?(スレッドセーフか?) - おそらく問題はないはずだがSparkSessionは使いまわしているので、
Using an existing SparkSession; some configuration may not take effect.
というウォーニングがでる。 - モデル読みを遅延評価(
lazy
)にしないと、なぜかエラーになった。
IrisPredictor.scala
import org.apache.spark.ml.tuning.CrossValidatorModel
import org.apache.spark.sql.{Encoders, SparkSession}
case class IrisData(f0:Double,f1:Double,f2:Double,f3:Double,target:Int)
object IrisPredictor {
lazy val model = CrossValidatorModel.load("savedmodel") //なぜか遅延評価にしないとうまくいかなかった。
val masterURL = "local[*]"
//予測処理
def predict(data:IrisData):Int = {
//セッション、モデルともにスレッドセーフ?リクエストで使いまわしていいのか?
val spark = SparkSession.builder().master(masterURL).getOrCreate()
val test = spark.createDataset(Seq(data))(Encoders.product[IrisData])
model.transform(test).select("prediction").head.getDouble(0).toInt
}
def stop():Unit = SparkSession.builder().master(masterURL).getOrCreate().stop()
}
Webアプリ側の処理
- JSONで受け取ったリクエストを
ReqJson型
にデコードし - それを
IrisData型
に詰め替えて - 予測して
-
ResJson型
にして - エンコードしてJSONをレスポンスする
IrisPredictService.scala
import org.http4s.HttpRoutes
import cats.effect.IO
import org.http4s.implicits._
import org.http4s.dsl.io._
object IrisPredictService {
case class ReqJson(f0:Double,f1:Double,f2:Double,f3:Double){
def toIrisData():IrisData = IrisData(f0,f1,f2,f3,0)
}
case class ResJson(target:Int)
val service = HttpRoutes.of[IO]{
//ルーティング
case req @ POST -> Root / "iris" / "predict" =>
//リクエスト、レスポンスをデコード、エンコードするための準備
import org.http4s.circe.{jsonOf,jsonEncoderOf}
import io.circe.generic.auto._
implicit val decoder = jsonOf[IO,ReqJson]
implicit val encoder = jsonEncoderOf[IO,ResJson]
for {
inputJson <- req.as[ReqJson] //リクエストをオブジェクトに変換して
test = inputJson.toIrisData //IrisDataを作って
res <- Ok(ResJson(IrisPredictor.predict(test))) //予測結果をレスポンスにする
} yield (res)
}.orNotFound
}
サーバの起動・停止処理
-
http4s
のマニュアル通りにやっているだけで、cats effect
のモナモナしたところはイマイチわかってない。
Main.scala
import org.http4s.server.blaze.BlazeServerBuilder
import cats.effect.{ContextShift, IO, Timer}
import scala.concurrent.ExecutionContext
object Main {
def main(args: Array[String]): Unit = {
implicit val cs:ContextShift[IO] = IO.contextShift(ExecutionContext.global)
implicit val timer:Timer[IO] = IO.timer(ExecutionContext.global)
//サーバの事前準備
val serverBuilder = BlazeServerBuilder[IO]
.bindLocal(9999) // http://localhost:9999
.withHttpApp(IrisPredictService.service)
//サーバの起動
val fiber = serverBuilder.resource.use(_ => IO.never).start.unsafeRunSync()
//キーが何か押されたら、サーバ停止とSparkSession停止
scala.io.StdIn.readLine()
fiber.cancel.unsafeRunSync()
IrisPredictor.stop()
}
}