Scala
Spark
ML

[Spark ML] Logistic Regression詳解および独自モデルの追加について

Spark MLがどのように分散学習を実現しているのかを最小構成のLogistic Regressionを再実装することで見ていきます.
簡単のため2クラス分類に限定します.
その後独自モデルの拡張方法について述べます.

Sparkのバージョンは2.2.0とします.
また、メンテナンスのみとなったmllib packageではなく、ml packageを対象とします。

EstimatorとTransformer

何を継承/実装すれば良いかについては,こちらは既に素晴らしい資料が存在します.
* Extend Spark ML for your own model/Transformer types
* spark.mlのAPIでXGBoostを扱いたい!

仔細は上に譲りますが,Estimatorのfitメソッドに学習データを渡すことで学習済みのモデルであるTransformerが返り,
Transformerのtransformメソッドに予測対象のデータを渡すことで予測結果が追記されたデータが返ります.
EstimatorとTransformerはともにPipelineStage のサブクラスでPipelineに組み込むことができます.


(引用元: https://jaceklaskowski.gitbooks.io/mastering-apache-spark/spark-mllib/spark-mllib-pipelines.html)

上記のModelはTransformerとなります.

Logistic Regressionがモデリングしている値は確率値なので,EstimatorのサブクラスであるProbabilisticClassifier,TransformerのサブクラスであるProbabilisticClassifierModelをそれぞれ継承してBinaryLogisticRegressionクラス,およびBinaryLogisticRegressionModelクラスを定義します.

利用したいtraitやmethodがsparkやml package限定の場合が多いのでorg.apache.spark.ml.classification以下にクラスを追加します.

package org.apache.spark.ml.classification

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset

class BinaryLogisticRegression(override val uid: String)
  extends ProbabilisticClassifier[Vector, BinaryLogisticRegression, BinaryLogisticRegressionModel] {

  override def copy(extra: ParamMap): BinaryLogisticRegression = defaultCopy(extra)

  override protected def train(dataset: Dataset[_]): BinaryLogisticRegressionModel = ???
}

class BinaryLogisticRegressionModel(override val uid: String)
  extends ProbabilisticClassificationModel[Vector, BinaryLogisticRegressionModel] {

  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = ???

  override def numClasses: Int = 2

  override protected def predictRaw(features: Vector): Vector = ???

  override def copy(extra: ParamMap): BinaryLogisticRegressionModel = ???
}

ここで型パラメータである Vector は特徴量の型です.
LogisticRegressionModelpredict は特徴量とcoefficients(weights)の線形和をとる処理,raw2probabilityInPlace は線形和にシグモイド関数を適用する処理を記載します.
簡単のため,以降intercept(bias)項はcoefficientsに含めて考えます.つまり特徴量を表すVectorの最後に1.0が追加されている状態とします.

class BinaryLogisticRegressionModel(
  override val uid: String,
  val coefficients: Vector)
  extends ProbabilisticClassificationModel[Vector, BinaryLogisticRegressionModel] {

  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector =
    Vectors.dense(rawPrediction.toDense.values.map(sigmoid(_)))

  override def numClasses: Int = 2

  override protected def predictRaw(features: Vector): Vector = {
    val margin = BLAS.dot(coefficients, features)
    Vectors.dense(-margin, margin)
  }

  override def copy(extra: ParamMap): BinaryLogisticRegressionModel = {
    val model = copyValues(new BinaryLogisticRegressionModel(uid, coefficients))
    model.setParent(parent)
}

BLASは numpyなどでも用いられているベクトル/行列演算のためのライブラリ(仕様)です. 1
predict 内で Vectors.dense(Array(-margin, margin)) を返すのはlabelが0の場合と1の場合の両方の確率をベクトルで保持するためです. これは org.apache.spark.ml.evaluation.BinaryClassifierEvaluator との互換性を維持し,AUC等を計算する場合に必要になります.

あとは train 内でcoefficientsを計算し,LogisticRegressionModel をインスタンス化して返すのみです.

学習

訓練データ集合に対するクロスエントロピー誤差関数の勾配を計算して降下方向にパラメータを更新します.
Spark MLにはDistributed Tensorflowと異なりParameter Serverが存在しないため(mini-batch)SGD系の学習よりはバッチ学習の方が向いています.
Spark ML標準のLogistic Regressionでは二次の勾配(の近似)まで考慮するL-BFGS (Lassoの場合はOWL-QN)を利用しています.

L-BFGSに関しては以下の記事がわかりやすいです.

Numerical Optimization: Understanding L-BFGS

Spark MLではL-BFGSはbreezeのコンポーネントを利用しています. 使用例は以下です.

import breeze.linalg.{DenseVector => BDV}

val trains = // 訓練データ集合
val numFeatures = // 特徴量 (+ intercept項) の数
val optimizer = new LBFGS[BDV[Double]]()
val initialCoefficients = Vectors.zeros(numFeatures)
val costFun = new BinaryLogisticLossFun(trains)
val x = optimizer.minimize(new CachedDiffFunction[BDV[Double]](costFun), new BDV[Double](initialCoefficients.toArray))
val trainedCoefficients = x.toArray.clone

ここで costFunbreeze.optimize.DiffFunction を継承し自作した BinaryLogisticLossFun のインスタンスとなります.
実装が必要なメソッドは calculate で以下のsignatureです.

class BinaryLogisticLossFun(trains: /* データ集合の型 */) extends DiffFunction[BDV[Double]] {

  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = ???
}

calculate はiterationの度にその時点のcoefficientsを表現したベクトルを受け取り,計算した誤差関数値と,そのcoefficientsに対する勾配ベクトルを返します.

SparkのRDDを利用しない場合は以下となります.

case class Datum(label: Double, features: Vector)

class BinaryLogisticLossFun(data: Seq[Datum]) extends DiffFunction[BDV[Double]] {

  val n = data.length

  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
    val (loss, multiplier) = data.foldLeft((0.0, 0.0)) { case ((loss, multiplier), Datum(label, features)) =>
      val p = sigmoid(coefficients dot new BDV[Double](features.toArray))
      (loss - (label * math.log(p) + (1.0 - label) * math.log(1.0 - p)) / n, multiplier + (p - label) / n)
    }
    (loss, data.foldLeft(Vectors.zeros(coefficients.size)) { case (acc, Datum(_, features)) =>
      BLAS.axpy(multiplier, features, acc)
      acc
    }.asBreeze.toDenseVector)
  }
}

Sparkを利用する場合はinstancesの型が RDD[Instance] となります.
各partition毎に勾配値の合計を計算し,最終的に全partitionの勾配値の合計を集計します.
分散処理されるのはバッチ学習における勾配計算のみなので、非分散環境と比較して結果が変わるということはありません.
多数のpartitionをdriverにて集約する際,driverに多量の負荷がかかるため階層的に集約する機能がRDDに用意されています.
それが treeAggregate です.
treeAggregate にはpartition内のデータの集約方法を記述した seqOp と,partition間の集約方法を記述した combOp とともに BinaryLogisticAggregator のインスタンスを渡します.

class BinaryLogisticLossFun(data: RDD[Datum]) extends DiffFunction[BDV[Double]] {

  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
    val bcCoefficients = data.context.broadcast(Vectors.fromBreeze(coefficients))

    val logisticAggregator = {
      val seqOp = (c: BinaryLogisticAggregator, instance: Datum) => c.add(instance)
      val combOp = (c1: BinaryLogisticAggregator, c2: BinaryLogisticAggregator) => c1.merge(c2)
      data.treeAggregate(new BinaryLogisticAggregator(bcCoefficients))(seqOp, combOp)
    }

    bcCoefficients.destroy(blocking = false)

    (logisticAggregator.loss, BDV(logisticAggregator.gradient.toArray))
  }
}

class BinaryLogisticAggregator(bcCoefficients: Broadcast[Vector]) extends Serializable {

  private var numData = 0
  private var lossSum = 0.0

  private val numCoefficients = bcCoefficients.value.size

  @transient
  private lazy val coefficients: Vector = bcCoefficients.value
  private lazy val localGradient = Vectors.zeros(numCoefficients)

  def add(datum: Datum): BinaryLogisticAggregator = datum match {
    case Datum(label, features) =>
      val margin = BLAS.dot(coefficients, features)
      BLAS.axpy(sigmoid(margin) - label, features, localGradient)
      lossSum += label * MLUtils.log1pExp(-margin) + (1.0 - label) * (MLUtils.log1pExp(-margin) + margin)
      numData += 1
      this
  }

  def merge(other: BinaryLogisticAggregator): BinaryLogisticAggregator  = {
    numData += other.numData
    lossSum += other.lossSum
    this
  }

  def loss: Double = lossSum / numData

  def gradient: Vector = {
    val result = Vectors.dense(localGradient.toArray.clone)
    BLAS.scal(1.0 / numData, result)
    result
  }
}

上記のとおり LogisticAggregatoraddseqOp に, mergecombOp に対応しています.

これで BinaryLogisticRegressiontrain メソッドは以下のように書けます.

override protected def train(dataset: Dataset[_]): BinaryLogisticRegressionModel = {
    val data = dataset.select("label", "features").rdd.map { case Row(label: Double, features: Vector) =>
      Datum(label, features)
    }
    val numFeatures = data.first().features.size

    val optimizer = new LBFGS[BDV[Double]]()
    val initialCoefficients = Vectors.zeros(numFeatures)
    val costFun = new BinaryLogisticLossFun(data)
    val x = optimizer.minimize(new CachedDiffFunction[BDV[Double]](costFun), new BDV[Double](initialCoefficients.toArray))
    val trainedCoefficients = Vectors.dense(x.toArray.clone)
    new BinaryLogisticRegressionModel(uid, trainedCoefficients)
  }
}

ここまでの全てのコードを以下に記載しています.
https://gist.github.com/mrkm4ntr/a7d2093cc23d2f077c2226e2d19d0bf6

独自モデルの追加

上記の実装をよくみると,誤差関数の実装は BinaryLogisticAggregatoradd メソッドにのみ関係することがわかります.
すなわち,とある誤差関数を定義してそれを最小化するというタイプのモデルであれば, add メソッドの実装を変更すれば(大雑把にいうと)実現できることになります.
実際Spark 2.3.0では共通コンポーネントとして RDDLossFunctionDifferentiableLossAggregator クラスが追加され,DifferentiableLossAggregatoradd メソッドを実装すればよくなっています.
例えば線形カーネルを用いたソフトマージンSVMである LinearSVC では DifferentiableLossAggregator を実装した HingeLossAggregator を用いることで実現してます.2
以下調整中ですが,Delayed Feedback Model 3 の実装です.
https://github.com/mrkm4ntr/spark-delayed-feedback-model

総括

Sparkのオリジナルコードを読むと多くのエラー処理やバリデーション処理,および最適化などが含まれているのでかなりの長さになっていますが,最小構成を見るとそれほど難しいことをやっていないことがお分かりかと思います.
Spark MLはまだまだ実装されているモデルが少ないので,拡張したい場合などの参考になれば幸いです.
ここまで読んでいただきありがとうございました.


  1. SparkにおいてLevel1 BLAS(ベクトル同士の演算)はJava実装が使われています.それより上はnative implementation、fallback時にJava実装が用いられます. 

  2. ただしquadratic programmingで解ける問題をL-BFGSで解いているので収束が遅い... 

  3. 広告配信などでコンバージョンがクリックよりも遅れるので学習時に本来正例だったデータが負例として扱われる問題に対応したモデル