LoginSignup
1
0

More than 3 years have passed since last update.

Spark機械学習モデルのWebAPI化メモ

Last updated at Posted at 2019-08-29

Sparkに依存せずにモデルを実行するMleapを使った方法の記事を記載、こっちの方がいいと思う。(2019/10/12追記)

はじめに

  • 備忘録のようなクソエントリー
  • Sparkの勉強がてらに作った機械学習モデルをWebAPIにするやーつ

Apache Spark MLlibで作ったモデルで予測サーバーを作りたい

  • Spark MLlibで作った機械学習モデルをWebAPIにする実験。
  • APサーバはhttp4sを利用。Sparkが動いているのもAPサーバが動いているのも同一マシン上。

build.sbt

build.sbt
name := "sparktest"
version := "0.1"

scalaVersion := "2.11.12" //Spark2.4.3だとscala2.11でないと動かなかった。
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.3"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.3"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "2.4.3"

libraryDependencies += "org.http4s" %% "http4s-dsl" % "0.20.10"
libraryDependencies += "org.http4s" %% "http4s-blaze-server" % "0.20.10"

libraryDependencies += "org.http4s" %% "http4s-circe" % "0.20.10"
libraryDependencies += "io.circe" %% "circe-generic" % "0.11.1"
libraryDependencies += "io.circe" %% "circe-literal" % "0.11.1"

モデル作って保存する

  • とりあえず適当に機械学習モデルをSpark MLlibを用いて作成する。
  • データはScikit-learnにある例のIrisのデータをCSVにしたもの
  • できたモデルをWebアプリから利用するため保存しておく。
import org.apache.spark.sql.{SparkSession,Encoders}
import org.apache.spark.types.StructType
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StandardScaler,VectorAssembler}
import org.apache.spark.ml.classfication.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassficationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator,CrossValidatorModel,ParamGridBuilder}

implicit val spark = SparkSession.builder().master("local[*]").getOrCreate()

case class IrisData(f0:Double,f1:Double,f2:Double,f3:Double,target:Int)

def train(path:String)(implicit spark:SparkSession):Unit ={
  val schema = new StructType()  //CSVファイルのスキーマ
    .add("id","Int")
    .add("f0","Double")
    .add("f1","Double")
    .add("f2","Double")
    .add("f3","Double")
    .add("target","Int")

  import spark.implicits._

  //CSVファイル読み込んでDataSetにする
  val ds = spark
    .read
    .schema(schema)
    .option("header",true)
    .csv("iris.csv")
    .select("f0","f1","f2","f3","target")
    .as[IrisData](Encoders.product[IrisData])

  //訓練データと検証データに分割
  val Array(train, test) = ds.randomSplit(Array(0.8, 0.2))

  //特徴ベクトルを1つの"features"列にする
  val assembler = new VectorAssembler()
    .setInputCols(Array("f0", "f1", "f2", "f3"))
    .setOutputCol("features")

  //標準化
  val scaler = new StandardScaler()
    .setInputCol(assembler.getOutputCol)
    .setOutputCol("scaledFeatures")

  //ランダムフォレスト
  val randomForest = new RandomForestClassifier()
    .setFeaturesCol(scaler.getOutputCol)
    .setLabelCol("target")


  //パイプライン
  val pipeline = new Pipeline()
    .setStages(Array(assembler, scaler, randomForest))


  //グリッドサーチ用のハイパーパラメータ
  val paramGrid = new ParamGridBuilder()
    .addGrid(randomForest.maxDepth, Array(8, 16))
    .addGrid(randomForest.numTrees, Array(16, 32))
    .build()

  //評価器
  val evaluator = new MulticlassClassificationEvaluator()
    .setLabelCol(randomForest.getLabelCol)
    .setPredictionCol(randomForest.getPredictionCol)

  //交差検定
  val validator = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(evaluator)
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(3)

  //学習
  val model = validator.fit(train)

  //学習結果の保存
  model.save(path)


  //予測
  val predict = model.transform(test)

  //予測結果の正解数
  val positiveCount = predict
    .select("target","prediction") //正解値と予測値のみ抽出し
    .map{data => if (data.getInt(0).toDouble == data.getDouble(1)) 1 else 0} //一致している場合だけ1
    .reduce{_ + _} //その合計値

  println("正解数:${positiveCount}\t正解率:${positiveCount / predict.count().toDouble}")
}

train("savedModel")

予測処理

  • SparkSessionをリクエスト間で共有しているけれど、本当に問題ないか?(スレッドセーフか?)
  • おそらく問題はないはずだがSparkSessionは使いまわしているので、Using an existing SparkSession; some configuration may not take effect.というウォーニングがでる。
  • モデル読みを遅延評価(lazy)にしないと、なぜかエラーになった。
IrisPredictor.scala
import org.apache.spark.ml.tuning.CrossValidatorModel
import org.apache.spark.sql.{Encoders, SparkSession}

case class IrisData(f0:Double,f1:Double,f2:Double,f3:Double,target:Int)

object IrisPredictor {
  lazy val model = CrossValidatorModel.load("savedmodel") //なぜか遅延評価にしないとうまくいかなかった。
  val masterURL = "local[*]"
  //予測処理
  def predict(data:IrisData):Int = {
    //セッション、モデルともにスレッドセーフ?リクエストで使いまわしていいのか?
    val spark = SparkSession.builder().master(masterURL).getOrCreate()
    val test = spark.createDataset(Seq(data))(Encoders.product[IrisData])
    model.transform(test).select("prediction").head.getDouble(0).toInt
  }

  def stop():Unit = SparkSession.builder().master(masterURL).getOrCreate().stop()
}

Webアプリ側の処理

  • JSONで受け取ったリクエストをReqJson型にデコードし
  • それをIrisData型 に詰め替えて
  • 予測して
  • ResJson型にして
  • エンコードしてJSONをレスポンスする
IrisPredictService.scala
import org.http4s.HttpRoutes
import cats.effect.IO
import org.http4s.implicits._
import org.http4s.dsl.io._

object IrisPredictService {
  case class ReqJson(f0:Double,f1:Double,f2:Double,f3:Double){
    def toIrisData():IrisData = IrisData(f0,f1,f2,f3,0)
  }
  case class ResJson(target:Int)

  val service = HttpRoutes.of[IO]{
    //ルーティング
    case req @ POST -> Root / "iris" / "predict" =>

      //リクエスト、レスポンスをデコード、エンコードするための準備
      import org.http4s.circe.{jsonOf,jsonEncoderOf}
      import io.circe.generic.auto._
      implicit val decoder = jsonOf[IO,ReqJson]
      implicit val encoder = jsonEncoderOf[IO,ResJson]

      for {
        inputJson <- req.as[ReqJson] //リクエストをオブジェクトに変換して
        test = inputJson.toIrisData //IrisDataを作って
        res <- Ok(ResJson(IrisPredictor.predict(test))) //予測結果をレスポンスにする
      } yield (res)
  }.orNotFound
}

サーバの起動・停止処理

  • http4sのマニュアル通りにやっているだけで、cats effectのモナモナしたところはイマイチわかってない。
Main.scala
import org.http4s.server.blaze.BlazeServerBuilder
import cats.effect.{ContextShift, IO, Timer}
import scala.concurrent.ExecutionContext

object Main {
  def main(args: Array[String]): Unit = {

    implicit val cs:ContextShift[IO] = IO.contextShift(ExecutionContext.global)
    implicit val timer:Timer[IO] = IO.timer(ExecutionContext.global)

    //サーバの事前準備
    val serverBuilder = BlazeServerBuilder[IO]
      .bindLocal(9999) // http://localhost:9999
      .withHttpApp(IrisPredictService.service)

    //サーバの起動
    val fiber = serverBuilder.resource.use(_ => IO.never).start.unsafeRunSync()

    //キーが何か押されたら、サーバ停止とSparkSession停止
    scala.io.StdIn.readLine()
    fiber.cancel.unsafeRunSync()
    IrisPredictor.stop()
  }
}

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0