LoginSignup
11
12

More than 3 years have passed since last update.

ロジスティック回帰を導出から実装までみっちり途中式解説

Posted at

線形単回帰知ったボク

ふーん
最小二乗法便利じゃん

ロジスティック回帰知ったボク

ふーん・・・
???なんでそれで係数求まるの???

ということで解説しながらやっていきます

ロジスティック回帰についての疑問点5stepで。

前提条件

ある調査によって男女の身長を求めた。
身長データの平均は160であった。
2行取り出した以下のデータを使いロジスティック回帰を行ってみよ。
ロジスティック回帰で切片項は使用しないこととする。

sex height
162
158

問1:ロジスティック回帰を行うためにデータを書き換えよ

ロジスティック回帰では目的の分類クラスを1,0に変換することで、クラスを確率として扱う。
1=100%として考えることで、0~1までの数値の回帰によってクラス分類を達成する。
そのため以下のように書き換えた。
また、切片項を使用しないので、事前に得られている平均値から各データを引いておく。

sex height
1 2
0 -2

問2:1から0までの値で出力するための変換関数(シグモイド関数)を書け

ロジスティック回帰ではシグモイド関数によって1から0の間に値を変換する。
数式に苦手感を感じる人はexp(x)という記号が2.718のx乗だと考えてくれればいい。

image.png

関数として描画すると以下のような形になる。

image.png

問3:ロジスティック回帰の尤度関数から負の対数尤度関数までの流れを記述せよ

まずロジスティック回帰ではシグモイド関数に

  • 男の時(sex = 1)のデータ(height = x)を入力すると1に近くなり(大きくなり)
  • 男でない時(sex = 0)のデータでは0に近くなる(小さくなる)

という内容を満たすように、係数wをデータxにかけることで調節していくのが目的の一つである。

ここで

  • 「男の時は、男でないときは」とかプログラムじゃないんだから一つの式に表して計算したい
  • 「大きくする、とか小さくする」とかどっちかに統一しろや

って気持ちがあるのでこれを満たすような関数を考えてやる。

image.png

そもそも、シグモイド関数にxを入れた時に出てくる値yは0~1の確率値として考えて左辺をP()で表す。
(推定値なので$\hat{y}$や$\hat{sex}$とした方が適切かもしれない)

  • 右辺の式はsexが1ならば後ろの部分は0乗となり前部分だけが残る
  • sexが0ならば後ろの部分だけが残り、シグモイド関数にxを代入したときに0に近いほど1から引かれる値が大きく成功とみなされる
  • この計算を各データに行い、最大になるような係数wを求めたい
  • 最大を考える時にすべての出力を掛け算して最大になるときのwを考えよう

以上のことからwに関する関数が考えられた。
この関数を尤度関数と呼ぶ。

総乗記号を使い一般化して書くと
xに添え字iを書き忘れたが

image.png

という式で表現できる。

しかし、

  • 0~1の間の数字を何度も掛け算していくと0に極めて近い値になってしまう
  • 掛け算は計算機君でもめんどくさがる

ので、簡単な足し算にしたいよね。
ってことで掛け算の性質を足し算の性質に変換するlog君に登場してもらう。

すると、総乗記号は総和記号に書き換わり以下のように書き直せる。

image.png

尤度関数に対数logを使ったので、これを対数尤度関数と呼ぶ

まだまだ変形すんで。
そもそも尤度関数も対数尤度関数も最大にしてくれるような係数wを求めるのが目的だった。
最適化では最小化問題に書き換えるという慣習がある。
(きっと計算数値が小さくて済むからだと考えているが違ったらゴメンなさい。教えて。)

対数のグラフlog(x)を-log(x)に書き直してみると逆転しているだけだってことがわかると思う。

image.png

なので全体にマイナスを掛け算してやると

image.png

となり、これを負の対数尤度関数と呼ぶ。
よく交差エントロピーなどとも呼ばれる

問4:負の対数尤度関数を使ってどうやって係数wを決めていくのか

負の対数尤度関数と、

パターン1.男のデータの時の尤度関数の値の関係は

image.png

パターン2.男でないときの関係

image.png

となる。つまり、負の対数尤度関数の値が小さくなる時は、関数の傾きが小さくなっている時である。
よって関数の傾きを小さく
というか、0にするようにwを決めていけばいいことがわかる。

問5:じゃあ傾き求めてみて

5-1全体を微分する準備

傾きを求めるために負の対数尤度関数の導関数を求めていく。

合成関数の微分は複雑なのでみっちりと順を追って記述していく。

負の対数尤度関数を係数wに関する関数J(w)として考える
あと、sexって書くのが面倒なのでsとして省略した。
微分は各項に対して行うものなので、前の項と後の項を微分する。

image.png

一般規則の復習

ここで対数logの中に関数f(x)が入っている時の微分は、合成関数の基本的な規則より

image.png

対数の微分は逆数となり、括弧内の関数の微分を後から掛け算することで求まると思い出しておく。

image.png

つまり計算するとこうなる。

5-2対数の微分後のちょっとした微分処理と式整理

後ろの項の

image.png

は計算すると

image.png

となり、前の項と共通になるのでくくりだせる

image.png

あと残す微分をやっつけていく。

5-3シグモイド関数の微分

wxが計算をややこしくしそうなので、zに置き換えてから、合成関数として求めてやる。

image.png

zの微分は求めるの簡単

image.png

シグモイド関数の微分も合成関数の微分として順番に解いていく

image.png

なんでこんなこと考えつくんだろって思うけど式変形していくと、

image.png

となる。
最後にはzに関するシグモイド関数が出てきた。

wxをzと置いていたことを思い出して、代入してやると、

image.png

zに関するシグモイド関数の微分を計算し終えることができた。

先ほどのZと置いて合成関数の微分を行った結果を合わせると

image.png

となる。

5-2に5-3の計算結果をあわせる

さて、少し5-2戻って式変形をする
純粋に式を変形しているだけ。

image.png

どうしてこんな式変形が思いつくんや
すると、角括弧内とシグモイド関数の微分によって出てきた部分が約分できる。

あとは純粋に綺麗にしていくと

image.png

すっごい綺麗にまとまった。

5-4係数を求め、もともとのロジスティック回帰と、学習後のロジスティック回帰を比較せよ

もともとのシグモイド関数にデータを入力した場合

image.png

これでも0.5よりも小さいか大きいかで判別させたら十分機能はしますが、
今回はこれをより0と1に近づけていく学習なので、
勾配降下法によって学習させていきます。

更新式は

新しい係数 <- 係数 - 学習率*(シグモイド関数(-w*x) - ラベルy)x

です。
勾配降下法については過去記事で紹介しております。

初期値を0として係数を学習させていくと

image.png

係数=5あたりでゆっくりになってきました。
どれだけ大きくしていってもうまく分けられるのできっと上昇し続けるのではないかと思います。

そこで得られた係数を使ってシグモイド関数を書いてみます。

image.png

S字の曲線はより急な変化になり、1や0にうまくデータを振れていることがわかりました。

ということで、係数を求め、シグモイド関数の比較ができました。

ちなみにRのglm関数で求めた場合は係数=11になりました。

image.png

以上

スッキリした。

以下コード

x<-seq(-50,50,0.1)

sigmoid <-function(x){
  A<-1
  B<-1+exp(-x)
  return(A/B)
}

y<-sigmoid(x)
plot(x,y,type="l",main="sigmoid function")

x<-seq(0.1,20,0.1)
y1 <-log(x)
y2<- -log(x)
plot(x,y1,type="l",xlim=c(0,20),ylim=c(-5,5),main="red = -log(x)")
points(x,y2,type="l",col="red")

plot(y,-log(y),type="l",xlab = "sigmoid(-wx)",ylab="cross ent")

x<-seq(-50,50,0.1)
y3<-1-sigmoid(x)
y<-sigmoid(x)

plot(y,-log(y3),type="l",xlab = "sigmoid(-wx)",ylab="cross ent")

#################

x<-c(2,-2)
y<-c(1,0)

plot(x=x,y=y)
data<-data.frame(x=x,y=y)

w_0 <- 0
lr <- 0.5
w_list<-NULL
ite<-10000
lis<-seq(1,ite,length.out = 100)
for(i in 1:ite){
  rand<-sample(2)[1]
  data_x<-data$x[rand]
  label<-data$y[rand]
  w_new <- w_0 - lr * (sigmoid(x=w_0*data_x) - label)*data_x
  w_0 <- w_new
  if(sum(lis%in%i)==1){
    w_list<-c(w_list,w_new)
  }
}
w_0
plot(1:length(w_list),w_list)
#dev.off()

dd<-seq(-3,3,0.1)
plot(dd,sigmoid(dd))
abline(v=data$x[1])
abline(v=data$x[2],col="red")

sigmoid2 <- function(x,w){
  A<-1
  B<-1+exp(-x*w)
  return(A/B)
}

dd<-seq(-3,3,0.1)
plot(dd,sigmoid2(dd,w=w_0),type="l")
abline(v=data$x[1])
abline(v=data$x[2],col="red")

res = glm(y ~ .-1,data, family=binomial)
res_sum<-summary(res)
fit = fitted(res)
plot(data$x, fit,col="red")
points(dd,sigmoid2(dd,w=res_sum$coefficients[1]),type="l")
11
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
12