Spark MLがどのように分散学習を実現しているのかを最小構成のLogistic Regressionを再実装することで見ていきます.
簡単のため2クラス分類に限定します.
その後独自モデルの拡張方法について述べます.
Sparkのバージョンは2.2.0とします.
また、メンテナンスのみとなったmllib packageではなく、ml packageを対象とします。
EstimatorとTransformer
何を継承/実装すれば良いかについては,こちらは既に素晴らしい資料が存在します.
仔細は上に譲りますが,Estimatorのfitメソッドに学習データを渡すことで学習済みのモデルであるTransformerが返り,
Transformerのtransformメソッドに予測対象のデータを渡すことで予測結果が追記されたデータが返ります.
EstimatorとTransformerはともにPipelineStage のサブクラスでPipelineに組み込むことができます.
上記のModelはTransformerとなります.
Logistic Regressionがモデリングしている値は確率値なので,EstimatorのサブクラスであるProbabilisticClassifier,TransformerのサブクラスであるProbabilisticClassifierModelをそれぞれ継承してBinaryLogisticRegressionクラス,およびBinaryLogisticRegressionModelクラスを定義します.
利用したいtraitやmethodがsparkやml package限定の場合が多いのでorg.apache.spark.ml.classification以下にクラスを追加します.
package org.apache.spark.ml.classification
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset
class BinaryLogisticRegression(override val uid: String)
extends ProbabilisticClassifier[Vector, BinaryLogisticRegression, BinaryLogisticRegressionModel] {
override def copy(extra: ParamMap): BinaryLogisticRegression = defaultCopy(extra)
override protected def train(dataset: Dataset[_]): BinaryLogisticRegressionModel = ???
}
class BinaryLogisticRegressionModel(override val uid: String)
extends ProbabilisticClassificationModel[Vector, BinaryLogisticRegressionModel] {
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = ???
override def numClasses: Int = 2
override protected def predictRaw(features: Vector): Vector = ???
override def copy(extra: ParamMap): BinaryLogisticRegressionModel = ???
}
ここで型パラメータである Vector
は特徴量の型です.
LogisticRegressionModel
の predict
は特徴量とcoefficients(weights)の線形和をとる処理,raw2probabilityInPlace
は線形和にシグモイド関数を適用する処理を記載します.
簡単のため,以降intercept(bias)項はcoefficientsに含めて考えます.つまり特徴量を表すVectorの最後に1.0が追加されている状態とします.
class BinaryLogisticRegressionModel(
override val uid: String,
val coefficients: Vector)
extends ProbabilisticClassificationModel[Vector, BinaryLogisticRegressionModel] {
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector =
Vectors.dense(rawPrediction.toDense.values.map(sigmoid(_)))
override def numClasses: Int = 2
override protected def predictRaw(features: Vector): Vector = {
val margin = BLAS.dot(coefficients, features)
Vectors.dense(-margin, margin)
}
override def copy(extra: ParamMap): BinaryLogisticRegressionModel = {
val model = copyValues(new BinaryLogisticRegressionModel(uid, coefficients))
model.setParent(parent)
}
BLASは numpyなどでも用いられているベクトル/行列演算のためのライブラリ(仕様)です. 1
predict
内で Vectors.dense(Array(-margin, margin))
を返すのはlabelが0の場合と1の場合の両方の確率をベクトルで保持するためです. これは org.apache.spark.ml.evaluation.BinaryClassifierEvaluator
との互換性を維持し,AUC等を計算する場合に必要になります.
あとは train
内でcoefficientsを計算し,LogisticRegressionModel
をインスタンス化して返すのみです.
学習
訓練データ集合に対するクロスエントロピー誤差関数の勾配を計算して降下方向にパラメータを更新します.
Spark MLにはDistributed Tensorflowと異なりParameter Serverが存在しないため(mini-batch)SGD系の学習よりはバッチ学習の方が向いています.
Spark ML標準のLogistic Regressionでは二次の勾配(の近似)まで考慮するL-BFGS (Lassoの場合はOWL-QN)を利用しています.
L-BFGSに関しては以下の記事がわかりやすいです.
Numerical Optimization: Understanding L-BFGS
Spark MLではL-BFGSはbreezeのコンポーネントを利用しています. 使用例は以下です.
import breeze.linalg.{DenseVector => BDV}
val trains = // 訓練データ集合
val numFeatures = // 特徴量 (+ intercept項) の数
val optimizer = new LBFGS[BDV[Double]]()
val initialCoefficients = Vectors.zeros(numFeatures)
val costFun = new BinaryLogisticLossFun(trains)
val x = optimizer.minimize(new CachedDiffFunction[BDV[Double]](costFun), new BDV[Double](initialCoefficients.toArray))
val trainedCoefficients = x.toArray.clone
ここで costFun
は breeze.optimize.DiffFunction
を継承し自作した BinaryLogisticLossFun
のインスタンスとなります.
実装が必要なメソッドは calculate
で以下のsignatureです.
class BinaryLogisticLossFun(trains: /* データ集合の型 */) extends DiffFunction[BDV[Double]] {
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = ???
}
calculate
はiterationの度にその時点のcoefficientsを表現したベクトルを受け取り,計算した誤差関数値と,そのcoefficientsに対する勾配ベクトルを返します.
SparkのRDDを利用しない場合は以下となります.
case class Datum(label: Double, features: Vector)
class BinaryLogisticLossFun(data: Seq[Datum]) extends DiffFunction[BDV[Double]] {
val n = data.length
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val (loss, multiplier) = data.foldLeft((0.0, 0.0)) { case ((loss, multiplier), Datum(label, features)) =>
val p = sigmoid(coefficients dot new BDV[Double](features.toArray))
(loss - (label * math.log(p) + (1.0 - label) * math.log(1.0 - p)) / n, multiplier + (p - label) / n)
}
(loss, data.foldLeft(Vectors.zeros(coefficients.size)) { case (acc, Datum(_, features)) =>
BLAS.axpy(multiplier, features, acc)
acc
}.asBreeze.toDenseVector)
}
}
Sparkを利用する場合はinstancesの型が RDD[Instance]
となります.
各partition毎に勾配値の合計を計算し,最終的に全partitionの勾配値の合計を集計します.
分散処理されるのはバッチ学習における勾配計算のみなので、非分散環境と比較して結果が変わるということはありません.
多数のpartitionをdriverにて集約する際,driverに多量の負荷がかかるため階層的に集約する機能がRDDに用意されています.
それが treeAggregate
です.
treeAggregate
にはpartition内のデータの集約方法を記述した seqOp
と,partition間の集約方法を記述した combOp
とともに BinaryLogisticAggregator
のインスタンスを渡します.
class BinaryLogisticLossFun(data: RDD[Datum]) extends DiffFunction[BDV[Double]] {
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val bcCoefficients = data.context.broadcast(Vectors.fromBreeze(coefficients))
val logisticAggregator = {
val seqOp = (c: BinaryLogisticAggregator, instance: Datum) => c.add(instance)
val combOp = (c1: BinaryLogisticAggregator, c2: BinaryLogisticAggregator) => c1.merge(c2)
data.treeAggregate(new BinaryLogisticAggregator(bcCoefficients))(seqOp, combOp)
}
bcCoefficients.destroy(blocking = false)
(logisticAggregator.loss, BDV(logisticAggregator.gradient.toArray))
}
}
class BinaryLogisticAggregator(bcCoefficients: Broadcast[Vector]) extends Serializable {
private var numData = 0
private var lossSum = 0.0
private val numCoefficients = bcCoefficients.value.size
@transient
private lazy val coefficients: Vector = bcCoefficients.value
private lazy val localGradient = Vectors.zeros(numCoefficients)
def add(datum: Datum): BinaryLogisticAggregator = datum match {
case Datum(label, features) =>
val margin = BLAS.dot(coefficients, features)
BLAS.axpy(sigmoid(margin) - label, features, localGradient)
lossSum += label * MLUtils.log1pExp(-margin) + (1.0 - label) * (MLUtils.log1pExp(-margin) + margin)
numData += 1
this
}
def merge(other: BinaryLogisticAggregator): BinaryLogisticAggregator = {
numData += other.numData
lossSum += other.lossSum
this
}
def loss: Double = lossSum / numData
def gradient: Vector = {
val result = Vectors.dense(localGradient.toArray.clone)
BLAS.scal(1.0 / numData, result)
result
}
}
上記のとおり LogisticAggregator
の add
が seqOp
に, merge
が combOp
に対応しています.
これで BinaryLogisticRegression
の train
メソッドは以下のように書けます.
override protected def train(dataset: Dataset[_]): BinaryLogisticRegressionModel = {
val data = dataset.select("label", "features").rdd.map { case Row(label: Double, features: Vector) =>
Datum(label, features)
}
val numFeatures = data.first().features.size
val optimizer = new LBFGS[BDV[Double]]()
val initialCoefficients = Vectors.zeros(numFeatures)
val costFun = new BinaryLogisticLossFun(data)
val x = optimizer.minimize(new CachedDiffFunction[BDV[Double]](costFun), new BDV[Double](initialCoefficients.toArray))
val trainedCoefficients = Vectors.dense(x.toArray.clone)
new BinaryLogisticRegressionModel(uid, trainedCoefficients)
}
}
ここまでの全てのコードを以下に記載しています.
https://gist.github.com/mrkm4ntr/a7d2093cc23d2f077c2226e2d19d0bf6
独自モデルの追加
上記の実装をよくみると,誤差関数の実装は BinaryLogisticAggregator
の add
メソッドにのみ関係することがわかります.
すなわち,とある誤差関数を定義してそれを最小化するというタイプのモデルであれば, add
メソッドの実装を変更すれば(大雑把にいうと)実現できることになります.
実際Spark 2.3.0では共通コンポーネントとして RDDLossFunction
と DifferentiableLossAggregator
クラスが追加され,DifferentiableLossAggregator
の add
メソッドを実装すればよくなっています.
例えば線形カーネルを用いたソフトマージンSVMである LinearSVC
では DifferentiableLossAggregator
を実装した HingeLossAggregator
を用いることで実現してます.2
以下調整中ですが,Delayed Feedback Model 3 の実装です.
総括
Sparkのオリジナルコードを読むと多くのエラー処理やバリデーション処理,および最適化などが含まれているのでかなりの長さになっていますが,最小構成を見るとそれほど難しいことをやっていないことがお分かりかと思います.
Spark MLはまだまだ実装されているモデルが少ないので,拡張したい場合などの参考になれば幸いです.
ここまで読んでいただきありがとうございました.
-
SparkにおいてLevel1 BLAS(ベクトル同士の演算)はJava実装が使われています.それより上はnative implementation、fallback時にJava実装が用いられます. ↩
-
ただしquadratic programmingで解ける問題をL-BFGSで解いているので収束が遅い... ↩
-
広告配信などでコンバージョンがクリックよりも遅れるので学習時に本来正例だったデータが負例として扱われる問題に対応したモデル
https://github.com/mrkm4ntr/spark-delayed-feedback-model ↩