8
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

信州大学 kstmAdvent Calendar 2018

Day 4

SQLで最急降下する

Last updated at Posted at 2018-12-03

信州大学 kstm Advent Calendar 2018の4日目です

最近BigQueryやSQLServerに機械学習系のアルゴリズムが追加されてますね.
今回はSQLの再帰クエリを使って最急降下して遊びました.

最急降下法

関数$f(x)$について,以下の更新式で極値となるパラメータの値を求めます.

$$
x \leftarrow x - \eta \frac{\partial\mathcal{f}}{\partial x}
$$

ためしにSQLで次の関数を最小化してみましょう.

$$
y(x) = x^2 - 6x + 9
$$

$y' = 2x - 6$,学習率を$\eta=0.01$とするとこんなかんじ

with
    recursive gd(iter, x, y) as (
        select
            0, cast(0.0 as decimal(6,3)), cast(9 as decimal(6,3))
        union all
        select
            iter+1, x-0.01*(2.0*x-6.0), pow(x-0.01*(2.0*x-6.0),2)-6*(x-0.01*(2.0*x-6.0)) + 9 
        from
            gd
        where
            iter < 1000
            and y - (pow(x-0.1*(2.0*x-6.0),2)-6*(x-0.1*(2.0*x-6.0)) + 9) > 0.001
    )
select * from gd order by iter desc limit 1

結果は...

iter x y
197 2.941 0.003

できた!

線形回帰もやってみる

$y \sim \mathcal{N}(a + bx, \sigma)$のようなデータを適当に作っておきます.
ここでは$a=1, b=-3$にしました.

plot.png

誤差関数と更新式は以下.
$$ \mathcal{L} = \frac{1}{2N} \sum_{i=1}^{N}(a+bx_i-y_i)^2 $$

\begin{align}
\frac{\partial\mathcal{L}}{\partial a} &= \frac{1}{N} \sum_{i=1}^{N}(a+bx_i-y_i) \\
\frac{\partial\mathcal{L}}{\partial b} &= \frac{1}{N} \sum_{i=1}^{N}(a+bx_i-y_i)x_i
\end{align}

学習率$\eta=0.01$で,誤差の変化が$0.01$以下になったら終了するようにしました.


with
    recursive agg as (
        select
            sum(x) s_x,
            sum(y) s_y,
            sum(x*x) s_x2,
            sum(y*y) s_y2,
            sum(x*y) s_xy,
            count(*) n
        from lr
    ),
    gd(iter, a, b, loss) as (
        select
            0, cast(1.0 as decimal(10, 5)), cast(1.0 as decimal(10, 5)), cast(100.0 as decimal(10, 5))
        from lr
        union all
        select
            iter+1,
            a-0.01*(n*a+b*s_x-s_y)/n,
            b-0.01*(a*s_x+b*s_x2-s_xy)/n,
            (n*pow(a-0.01*(n*a+b*s_x-s_y)/n,2)
             +pow(b-0.01*(a*s_x+b*s_x2-s_xy),2)*s_x2
             +s_y2
             +2*(a-0.01*(n*a+b*s_x-s_y)/n)*(b-0.01*(a*s_x+b*s_x2-s_xy))*s_x
             -2*(a-0.01*(n*a+b*s_x-s_y)/n)*s_y
             -2*(b-0.01*(a*s_x+b*s_x2-s_xy))*s_xy)/(2*n)
        from
            gd, agg
        where
            iter < 5000
            and loss > 0.1
            and abs(
                loss-(n*pow(a-0.01*(n*a+b*s_x-s_y)/n,2)
                +pow(b-0.01*(a*s_x+b*s_x2-s_xy),2)*s_x2
                +s_y2
                +2*(a-0.01*(n*a+b*s_x-s_y)/n)*(b-0.01*(a*s_x+b*s_x2-s_xy))*s_x
                -2*(a-0.01*(n*a+b*s_x-s_y)/n)*s_y
                -2*(b-0.01*(a*s_x+b*s_x2-s_xy))*s_xy)/(2*n)) > 0.01
    )
select * from gd order by iter desc limit 1

結果はこちら

iter a b loss
18 0.34115 -2.84881 3.31067

plot.png

いいかんじですね.

今回の実装を実行したnotebookはこちらです.

8
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?