Udemyの【キカガク流】人工知能・機械学習 脱ブラックボックス講座 - 初級編 -という講座にて,単回帰分析について学んだので忘れない内に忘備録を残しておきます
ちなみにQiitaにはいつもお世話になっていても,自分で記事を書くということは恐れ多くて一度もしていなかったのですが今回初投稿します.
単回帰分析とは
単回帰分析という大層な名前がついていますが,要は中学校のときに習った $y = ax + b$というやつです で,この1次関数に未知のデータを入れると,予測してくれるって感じです
なんだそんだけかよって感じですがそんだけです ちなみに変数が1つの場合の回帰分析だから**「単回帰」**なんです じゃあ変数が2つ以上になったらどうなるかというと,「重回帰」分析になります 単純ですね
ちなみに回帰という言葉,なぜ1次関数の式に回帰って言葉がついているかって思いませんか
僕はずっと謎だったのですが,以前統計学の本を読んでいると回帰について書いてありました 内容は,身長の低い人親から生まれてくる子供はもっと低くなるわけではなく,少し高い子が生まれてくる「平均への回帰」という意味が語源だそうです
確かに言われてみれば身長低い親からさらに低いor 同じ子供がずっと生まれたら無限に小さくなりますもんね いやーよく出来てますね あ,ふーんって感じですよね 次に行きます
データの中心化について
単回帰分析では**データの中心化(センタリング)**ということをします
というのは中心化というくらいなんで,バラバラのデータを中心に持って来る必要があります じゃあどうするのかというと,平均を使うわけです
つまりデータ全部に対して,平均を引きます そうすることによって,データが中心(グラフでいう原点)に寄っていきます
例えば日本人男性の身長のデータがあった時に,大体170センチ前後なので全部のデータから170を引けばまあ大体プラス・マイナス5くらいには収まるだろうって感じしますよね?
で,中心化は何がいいのかって話なんですが,定数項を消せるということです
どういうことかというと, $ \hat{y} = ax + b$ が $ \hat{y} = ax $ になるってことなんですよ(ちなみにこのyの上についている記号はハットという記号で,統計学では予測値に対して使う記号です)
いや,そんだけかよって思いますが,一つ文字を減らせたのはでかいです で,なぜそうなるかというと,中心に寄ったので原点を通るとみなせるからです
一応書くと,原点を通るグラフは切片$b$が0なので$ y = ax $になりますよね
評価関数について
評価関数とは損失関数とも言うらしいですが,要は実測値と予測値の差のことです
当たり前ですが,評価関数は小さければ小さいほどいいです というのはこの誤差が小さいということは回帰式のモデルはより正確ということを意味するからです
で,どうすんのかっていうと誤差を小さくするようパラメータを調整するってことです 当たり前すぎますね
具体的には $ (y - \hat{y})^2 $というyの実測値とyの予測値の二乗誤差を計算して,その後に微分をします ちなみになんで二乗なのかというと,二次関数になる=微分可能 だからです
ちなみに絶対値を取ると,カクカクなグラフができ,微分不可能です
二乗誤差の導出
評価関数 $ L = \sum_{n=0}^N (y_n - \hat{y_n}) ^2 $
評価関数は出来たんで,あとはどうするかっていうと先程と言った通り微分するだけです その後,最小化するので微分して0にします
つまり$\frac {\partial}{\partial a } (L) = 0 $ということをしてあげればいいのです
要は評価関数$L$を$a$で偏微分して,$0$にするってことですね
また$\hat{y} = ax$だったので$\hat{y_n} = ax_n $になります
これを代入すると$ L = \sum_{n=0}^N (y_n - ax_n) ^2 $ になり
展開すると$ \sum_{n=0}^N (y_n^2 -2y_nax_n + a^2x_n^2)$ になります
そしてこれをaで偏微分するので
$ \displaystyle \frac{\partial}{\partial a}\sum_{n=0}^N (y_n^2 -2y_nax_n + a^2x_n^2)$
$= \displaystyle 2\sum_{n=0}^N(y_nx_n) - 2\sum_{n=0}^N(x_n^2)a$
これが0になるので変形するとにこうなります
$$\displaystyle\sum_{n=0}^N(y_nx_n) = a\sum_{n=0}^N(x_n^2)$$
そして$a$を求めるように変形すると
$$ a = \displaystyle\frac{\sum_{n=0}^Ny_nx_n}{\sum_{n=0}^Nx_n^2} $$
ということでaの式が求まりました わーい
ここで気をつけてほしいのは,ここでいう$x_n$と$y_n$は中心化をしたものということです なので実際に求めるにはもととなるデータに中心化,つまり平均を引くという作業をする必要があります そして上の式を使ってパラメータ$a$を求めるという流れになります
例えば$x = \{1, 2, 3\}$のとき$y=\{2, 4, 6\}$という数値を学習させ,$x =1.5$のデータのときの$y$は何であるか予測するという簡単な問題があったとします
この問題を解くにはまず$x,y$の平均を求め中心化を行います
まずxの平均$\bar{x} = 2$,yの平均$\bar{y} = 4$となります
そしてxを中心化した$x_c = \{-1,0,1\}$,yを中心化した$y_c = \{-2,0,2\}$です あとは上の式に代入すると$a = 2$となります
そして$\hat{y}=2x$とわかったので,$x=1.5$を代入すると$\hat{y}=3$が求まりました
という流れです これはぶっちゃけただの比例の式なんでメリット全く無いですが,実際には大量のバラバラなデータを使うんで中心化が大事になってきます
さいごに
Qiitaに初投稿ということやLaTexを使い慣れていないということもあって,お見苦しい点があったかもしれませんが,誰かの参考になったなら幸いです
また上に書いたUdemyの講座は,数学のわからない僕のような人間でも理解することが出来るほどわかりやすいのでめちゃくちゃオススメです
参考にしたサイト
両方めちゃくちゃわかりやすくて助かりました
Qiitaの数式チートシート
Qiita Markdown 書き方 まとめ