LoginSignup
2
0

More than 5 years have passed since last update.

scala Breezeでlogistic regressionを解く

Posted at

はじめに

scalanlp/breeze: Breeze is a numerical processing library for Scala.はScalaで数値解析を行うためのライブラリです。

勉強がてらLogistic regression - Wikipediaを解くコードをBreezeで書いてみます。

今回実装するアルゴリズムはWikipediaのLogistic regressionのページでIteratively reweighted leeast squares(IRLS)と呼ばれているアルゴリズムです。

IRLS

推定したい$n+1$次元パラメータのベクタを$w^T = [\beta_0, \beta_1, ..., \beta_n]$、独立変数$x(i) = [1, x_1(i), ..., x_n(i)]^T$, ベルヌーイ分布の期待値$\mu(i) = \frac{1}{1 + e^{-w^Tx(i)}}$としたとき、以下の反復法でパラメータを推定することができます。

w_{k+1} = (X^TS_kX)^{-1}X^T(S_kXw_k + y - \mu_k) \tag{1}

ここで、$S_k$は対角要素が$\mu_k(i)(1 - \mu_k(i))$の対角行列、$\mu_k = [\mu(1), \mu(2), ...]^T$期待値のベクタ、

X = \begin{pmatrix}
1 & x_1(1) & ... & x_n(1) \\
1 & x_1(2) & ... & x_n(2) \\
\vdots & \vdots & & \vdots \\
\end{pmatrix}

は計画行列、 $y = [y(1), y(2), ...]$応答変数のベクタです。

上式の中で$i$は観測データのインデックスです。

実装

IDEAのWorksheet上でIRLSを実装してみたいと思います。

動作確認のためのデータはUCLAのコースから持ってきます。
https://stats.idre.ucla.edu/stat/data/binary.csv

ロジックのメインは$(1)$を実装している$step$関数と収束させるための$while$文です。$Pr$や$likelyhood$は発散回避用チェックに使用する確率・尤度を計算するための関数です。

import java.io.File
import breeze.linalg._

val file = new File("./data/binary.csv")
val mat = csvread(file, skipLines = 1)
val X = DenseMatrix.horzcat[DenseMatrix[Double], Double](DenseMatrix.ones[Double](mat(::, 1 to 3).rows, 1), mat(::, 1 to 3))
val y = mat(::, 0)

def u(w: DenseVector[Double]): DenseVector[Double] = X(*, ::).map(x =>  1/(1 + Math.exp(-(w dot x))))

def S(w: DenseVector[Double]) = diag(u(w).map(e => e * (1.0 - e)))

def step(Sk: DenseMatrix[Double], uk: DenseVector[Double], wk: DenseVector[Double]): DenseVector[Double] =
  inv(X.t * Sk * X) * X.t * (Sk * X * wk + y - uk)

def Pr(y: Int, x: DenseVector[Double], w: DenseVector[Double]) = {
  def h(x: DenseVector[Double], w: DenseVector[Double]) = 1/(1 + Math.exp(-(w dot x)))
  y match {
    case 0 =>
      1 - h(x, w)
    case 1 =>
      h(x, w)
  }
}
def likelyhood(w: DenseVector[Double]) = (0 until X.rows).map(i =>
  Pr(if(y(i) > 0.0) 1 else 0, X(i, ::).t, w)
).foldLeft(1.0)((s, p) => s * p)

// パラメータは適当な初期値で始める。
var w = DenseVector(0.0, 0.0, 0.0, 0.0)
var wNext = step(S(w), u(w), w)
val tolerance = 0.0001
var counter = 0

while(norm(w - wNext) > tolerance && counter < 100){
  w = wNext
  wNext = step(S(w), u(w), w)

  // 尤度が減少した場合はステップ幅を半分にする。
  while(likelyhood(w) > likelyhood(wNext)){
    wNext = w + ((w - wNext) /:/ 2.0)
  }

  counter += 1
}

s"counter = $counter, likelyhood(w) = ${likelyhood(w)}"
w

最後に

$step$関数の実装をみるとScalaの高い表現力を生かして数式のように直感的に書けて、Breeze素晴らしいですね。

参考リンク

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0