6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Spark mllibのmini batch SGD実装

Last updated at Posted at 2015-12-31

#はじめに

この記事はApache Spark Advent Calendar 2015の24日目の記事です。
Spark mllibのSVM, Logistic Regressionの学習に使われている mini batch SGD実装について調べた結果などを書きます。

SGDの実装は簡単で応用も広く有用な最適化手法なのでSparkの分散処理の枠組みの中でどう実装されているか興味を持っていました。

また、SGDの学習率を自動に最適化する手法, AdaGrad, Adam etc.をSpark mllibでどう実装できるかにも興味を持っていてその観点で調べてみました。

次の順番に軽く書いてみたいと思います。

  • GradientDescent.scalaのソースコードリーディング
  • SGD実装に関係するSpark JIRAチケット
  • AdaGradのプロトタイピング

GradientDescent.scalaのソースコードリーディング

Sparkのmini batch SGD実装は

  • SVM
  • LogisticRegression

に使われていて、その本体はGradientDescent.scalaになります。

このGradientDescent.scalaでSparkの分散処理フレームワークを使って、分散のmini batch SGDが実装されています。

ここで、GradientDescent.scalaの以下の観点で解説を行ってみたいと思います。

  • 分散でどうやってgradient updateやっているか
  • workerの結果を平均している。
  • どこで、broadcastで配って, driverで平均とっている
  • shuffleで平均している?
  • 学習係数はどう設定、更新されるのか?
  • AdaGradとかに修正実装できそうか?
  • BSPとの違いは?
  • SSPにするには何がいる?

GradientDescent.scalaのコアの処理はrunMiniBatchSGDで、

このmethod内だけで分散処理によるSGDが処理されています。
処理自体はさほど長くなく、基本的にSpark RDDによる集計処理の繰り返しのみで実現されています。

GradientDescent.scala
/**
     * For the first iteration, the regVal will be initialized as sum of weight squares
     * if it's L2 updater; for L1 updater, the same logic is followed.
     */
    var regVal = updater.compute(
      weights, Vectors.zeros(weights.size), 0, 1, regParam)._2

    var converged = false // indicates whether converged based on convergenceTol
    var i = 1
1  while (!converged && i <= numIterations) {
2    val bcWeights = data.context.broadcast(weights)
      // Sample a subset (fraction miniBatchFraction) of the total data
      // compute and sum up the subgradients on this subset (this is one map-reduce)
3      val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i)
        .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
          seqOp = (c, v) => {
            // c: (grad, loss, count), v: (label, features)
            val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
            (c._1, c._2 + l, c._3 + 1)
          },
          combOp = (c1, c2) => {
            // c: (grad, loss, count)
            (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
          })

      if (miniBatchSize > 0) {
        /**
         * lossSum is computed using the weights from the previous iteration
         * and regVal is the regularization value computed in the previous iteration as well.
         */
        stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
4       val update = updater.compute(
          weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble),
          stepSize, i, regParam)
        weights = update._1
        regVal = update._2

        previousWeights = currentWeights
        currentWeights = Some(weights)
5      if (previousWeights != None && currentWeights != None) {
          converged = isConverged(previousWeights.get,
            currentWeights.get, convergenceTol)
        }
      } else {
        logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
      }
      i += 1
    }

    logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
      stochasticLossHistory.takeRight(10).mkString(", ")))

    (weights, stochasticLossHistory.toArray)

上記のソースを順に説明します。

  • ※1がメインの反復処理で、maxIteration回処理するか収束するかまで勾配の計算、更新が処理されます。

  • ※2で、currentのweight vectorがdriverから各workerにbroadcastされます。各workerで計算された勾配が次のtreeAggregationで集計されて全体の勾配になります。この集計された勾配を使ってweight vectorが更新されます。

  • ※3 broadcastされたcurrent weight vectorを使って各exampleの損失、勾配が計算されて、treeAggregationで集計されてdriverに保持されます。

  • ※4 集計された勾配を使ってweight vectorを更新します。収束判定は、lossを見ているのではなくて、weight vectorの変化量で判断しています。※5で。

というわけで、Spark mllibのSGD実装は、各workerで独立にcurrent weight vectorを使って勾配を計算して、その勾配をdriverで集計してweight vectorを計算するBSPに基づく実装になっているようです。

Updator.scala

/**
 * :: DeveloperApi ::
 * Updater for L2 regularized problems.
 *          R(w) = 1/2 ||w||^2
 * Uses a step-size decreasing with the square root of the number of iterations.
 */
@DeveloperApi
class SquaredL2Updater extends Updater {
6 override def compute(
      weightsOld: Vector,
      gradient: Vector,
      stepSize: Double,
      iter: Int,
      regParam: Double): (Vector, Double) = {
    // add up both updates from the gradient of the loss (= step) as well as
    // the gradient of the regularizer (= regParam * weightsOld)
    // w' = w - thisIterStepSize * (gradient + regParam * w)
    // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
7  val thisIterStepSize = stepSize / math.sqrt(iter)
    val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
    brzWeights :*= (1.0 - thisIterStepSize * regParam)
    brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
    val norm = brzNorm(brzWeights, 2.0)

    (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm)
  }
}

weight vectorの更新はUpdator.scalaで行われます。L2正則化を使う場合は
※6の部分のcompute methodで行われます。

学習率、currentStepSizeは※7において 1/sqrt(iter)で調整されます。

spark mllibのSGD実装は以上のように、RDDを使って分散処理で各事例の勾配を計算をして、その集計をRDDのtreeAggregationを使ってはいますが、基本、普通の勾配法と何も変わらない、簡単で、わかりやすい実装になっています。各反復時に、workerの勾配計算結果をすべてdriverで集計するBSP実装にもなっています。

学習率に関しては、反復数により減衰させているシンプルなもので、AdaGrad, ADAM等には比較的簡単に拡張、変更ができます。
(どこで、どのように実装するかは、プログラミング、設計的な観点でいろいろ難しい、調整、議論があるかもしれませんが。。。)

このような、シンプルでわかりやすいSGD実装ですが、収束が遅い、mini batch samplingが効率的でないなど、いろいろ言われているようで、JIRAのチケットで議論されているようです。次にこれについて見ていきます。

Spark mllibのSGD実装に関係する Jiraチケット

SGD実装に関係するJIRAチケットに下記のようなものがあります。

  1. Use faster converging optimization method in MLlib
  2. Implement Nesterov's accelerated first-order method

2.は、普通の勾配法よりも速く収束するNesterov's accelerated first-order methodを使う実装でSGD, L-BFGSの実装を置き換えましょうという話。昨年の6月で議論が止まっている。pull-reqも出されているので手法を調査してみたいと思っているが見れてない。

3. はmini batch samplingする際にRDDのiterator interfaceを使っているので全なめするので?効率的でないからもっと効率的にできないか?というチケット

2.は今のSGD実装はBSPベースなのでSSPのような感じでもうちょっと効率化、高速化できんじゃないのという話? local updateとDistBriefみたいにするという実装は具体的に書いてないので詳細は不明

勾配法の実装も、BSPベースの分散勾配法の実装も簡単なので、上記のチケットもすんなり進むと思っていたが、どれも止まっているようです。

Spark mllib 教師あり学習アルゴリズム精度測定にも書いたが、現状のSGD実装を使ったSVM, Logistic Regressionの精度はL-BFGSによるLogisitic Regressionに比べて著しく悪かった。こSGDを使った場合に精度が下がる問題に手をつけられてないので、先に進めていないのだが、上のチケットに関係するSGDの拡張、improveするような何かをしたいなどと思ってます。

AdaGrad プロトタイピング

AdaGradはSGDのをweight vectorの要素毎の更新率をの勾配の履歴から計算する方法です。

\begin{equation}
{\bf w}^{(t+1)} = {\bf w}^{(t)} - \alpha \ {\bf k}^{(t)} \circ {\bf g}^{(t)}
\end{equation}
\begin{equation}
({\bf k}^{(t)})_a = 
\frac{1}{ \sqrt {\epsilon + \sum_{k=1}^t |g_a^{(k)}|^2 } }
\end{equation}

ここで、○はvectorのelement wiseの積。添字のaは成分を表す。

Spark mllibでの実装もごちゃごちゃ言わなければ(綺麗とか、ポリシーとか、いろいろ)非常に簡単にできると思われます。

で実装してその効果を共有したいとおもっていたのですが、、、もともとのSGDそのものの精度、収束がちょっとおかしくて、AdaGradによる効果を測定、評価ができていない状況です。。。

他の分類アルゴリズムのオリジナル実装をSpark上でも実装しなおしたいと思っていますので、その時に、SGD実装の妥当性とAdaGrad等の学習率の自動調整アルゴリズムの効果を評価したいと思っています。

#さいごに

Spark mllibのSGD実装に調べてみました。前から概要は理解していたのですが、詳細を見ることで

  • BSPと同じ。workerで各事例毎に勾配を計算して、driverで集計(super iteration)weight vectorを更新する。
  • 学習率の更新はUpdatorでやっている。AdaGrad etc.への変更は簡易にできそう
  • 勾配の集計はRDD treeAggregationを使って実装されている。難しいことは何もしてない。reduceしているだけ
  • 分散でSGDなので、各workerで勾配、weight vectorを更新して、driverでweight vectorの平均をとっていると勘違いしていたがそうではない。
  • SGDできているならonline学習もそのまま行けると思ったが効率的な分散処理だと自明じゃないっぽい。そこはちょっとまじめに考えたい
  • Asynchronous Complex Analytics in a Distributed
    Dataflow Architecture (Extended Abstract)
    とか、SSPとかを真面目に考えないといけない

というようなことがわかりました。この辺メインに考えたいなと思うこのごろです。

P.S. 的な

書いて最後にさらに備忘録的に追記すると

online学習とか、SSPとかやろうとするとSpark、RDDを使ってやるのと
素でAkkaを使って自前実装するのとどっちが良いか、どっちが簡単で開発しやすいかと思う。

未だにSpark上で開発するのは効率的に行えてないので。。。

Akkaを使うシステムのスケルトンを作ってしまって、
worker上のマシンがローカルのデータにアクセスするライブラリを作ってしまえば
Akka上で実装するほうが柔軟な気もしないでもないこのごろ。。。

6
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?