LoginSignup
0
0

More than 5 years have passed since last update.

機械学習で二次方程式の解 2

Last updated at Posted at 2017-12-03
a(x-α)(x-β)=0

a = 1 で固定し
解α,βを正の整数0~100までの組み合わせ
(0,0)
(0,1)
...
(100,99)
(100,100)
101×101=10201通りを学習させてみました。
ニューラルネットワークはとりあえずこうしてみました。
出力は恒等関数です。
図のxの使い方がややおかしいとは思いますが感じは伝わっていると思います。
二次方程式.png

学習率は1で何もしてません。

前回は基本的なミスがありました。

誤差が
double[] mE = new double[] { o_0_out - alpha, o_1_out - beta };
となっていたので
double[] mE = new double[] { alpha - o_0_out, beta - o_1_out };
に修正

誤差逆伝播が
w1から先に修正していたため
w0はw1の修正後の値で微分してしまっていて
w0から先に誤差逆伝播するよう修正しました。

なんとか形になりました。
前回はNaN(無限大)だったか
セルにこんな値が出てきてしまい
てんで話しにならなかったのですが
とりあえず、それっぽくなりました。

しかし、やってみてわかったのは
誤差逆伝播するとその入力に対しての解が出るようになってしまい
それまでの学習が全部関係ない結果になります。

本当は出力結果を全部掲載して説明したいところなんですが
データが多すぎておそらく投稿できないのと
できたとしてとても重くなりそうなので
最終的な重みの値だけ掲載しておきます。

wa_h0_0 -2309.71639140184
wa_h0_1 -2309.71639140184
wa_h0_2 -2309.71639140184
wb_h0_0 12644370.4100154
wb_h0_1 12644370.4100154
wb_h0_2 12644370.4100154
wc_h0_0 -639010033.780352
wc_h0_1 -639010033.780352
wc_h0_2 -639010033.780352
wB_h0_0 -123332.418400352
wB_h0_1 -123332.418400352
wB_h0_2 -123332.418400352

wh1_0_o0 9.40005648077475
wh1_0_o1 -100.266352010226
wh1_1_o0 9.40005648077475
wh1_1_o1 -100.266352010226
wh1_2_o0 9.40005648077475
wh1_2_o1 -100.266352010226
wB1_3_o0 100
wB1_3_o1 99


        double[,] w0 = new double[,]{
                                        { 0.5,0.5,0.5},
                                        { 0.5,0.5,0.5},
                                        { 0.5,0.5,0.5},
                                        { 0.5,0.5,0.5}};

        double[,] w1 = new double[,]{
                                        { 0.5,0.5},
                                        { 0.5,0.5},
                                        { 0.5,0.5},
                                        { 0.5,0.5}};

        private void nnw()
        {
            for (double alpha = 0; alpha <= 100; alpha++)
            {
                for (double beta = 0; beta <= 100; beta++)
                {
                    //
                    double a = 1;
                    double b = -(alpha + beta);
                    double c = alpha * beta;
                    double d = b * b - 4 * a * c;

                    //forward
                    double h0_0_in = a * w0[0, 0] + b * w0[1, 0] + c * w0[2, 0] + w0[3, 0];
                    double h0_1_in = a * w0[0, 1] + b * w0[1, 1] + c * w0[2, 1] + w0[3, 1];
                    double h0_2_in = a * w0[0, 2] + b * w0[1, 2] + c * w0[2, 2] + w0[3, 2];

                    double h0_0_out = s(h0_0_in);
                    double h0_1_out = s(h0_1_in);
                    double h0_2_out = s(h0_2_in);

                    double o_0_in = h0_0_out * w1[0, 0] + h0_1_out * w1[1, 0] + h0_2_out * w1[2, 0] + w1[3, 0];
                    double o_1_in = h0_0_out * w1[0, 1] + h0_1_out * w1[1, 1] + h0_2_out * w1[2, 1] + w1[3, 1];

                    double o_0_out = o_0_in;
                    double o_1_out = o_1_in;

                    double[] mE = new double[] { alpha - o_0_out, beta - o_1_out };

                    Invoke(new dldl(delegate
                    {
                        dataGridView1.Rows.Add(
                            a.ToString(), b.ToString(), c.ToString(), d.ToString(),
                            alpha.ToString(), beta.ToString(),
                            w0[0, 0].ToString(), w0[0, 1].ToString(), w0[0, 2].ToString(),
                            w0[1, 0].ToString(), w0[1, 1].ToString(), w0[1, 2].ToString(),
                            w0[2, 0].ToString(), w0[2, 1].ToString(), w0[2, 2].ToString(),
                            w0[3, 0].ToString(), w0[3, 1].ToString(), w0[3, 2].ToString(),
                            h0_0_in.ToString(), h0_0_out.ToString(),
                            h0_1_in.ToString(), h0_1_out.ToString(),
                            h0_2_in.ToString(), h0_2_out.ToString(),
                            w1[0, 0].ToString(), w1[0, 1].ToString(),
                            w1[1, 0].ToString(), w1[1, 1].ToString(),
                            w1[2, 0].ToString(), w1[2, 1].ToString(),
                            w1[3, 0].ToString(), w1[3, 1].ToString(),
                            o_0_in.ToString(), o_0_out.ToString(),
                            o_1_in.ToString(), o_1_out.ToString(),
                            mE[0].ToString(), mE[1].ToString()
                            );
                        Refresh();
                    }));

                    //backward
                    w0[0, 0] = w0[0, 0] - (((-mE[0] * w1[0, 0]) + (-mE[1] * w1[0, 1])) * s(-h0_0_out) * (1 - s(-h0_0_out)) * a);
                    w0[0, 1] = w0[0, 1] - (((-mE[0] * w1[0, 0]) + (-mE[1] * w1[0, 1])) * s(-h0_1_out) * (1 - s(-h0_1_out)) * a);
                    w0[0, 2] = w0[0, 2] - (((-mE[0] * w1[0, 0]) + (-mE[1] * w1[0, 1])) * s(-h0_2_out) * (1 - s(-h0_2_out)) * a);

                    w0[1, 0] = w0[1, 0] - (((-mE[0] * w1[1, 0]) + (-mE[1] * w1[1, 1])) * s(-h0_0_out) * (1 - s(-h0_0_out)) * b);
                    w0[1, 1] = w0[1, 1] - (((-mE[0] * w1[1, 0]) + (-mE[1] * w1[1, 1])) * s(-h0_1_out) * (1 - s(-h0_1_out)) * b);
                    w0[1, 2] = w0[1, 2] - (((-mE[0] * w1[1, 0]) + (-mE[1] * w1[1, 1])) * s(-h0_2_out) * (1 - s(-h0_2_out)) * b);

                    w0[2, 0] = w0[2, 0] - (((-mE[0] * w1[2, 0]) + (-mE[1] * w1[2, 1])) * s(-h0_0_out) * (1 - s(-h0_0_out)) * c);
                    w0[2, 1] = w0[2, 1] - (((-mE[0] * w1[2, 0]) + (-mE[1] * w1[2, 1])) * s(-h0_1_out) * (1 - s(-h0_1_out)) * c);
                    w0[2, 2] = w0[2, 2] - (((-mE[0] * w1[2, 0]) + (-mE[1] * w1[2, 1])) * s(-h0_2_out) * (1 - s(-h0_2_out)) * c);

                    w0[3, 0] = w0[3, 0] - (((-mE[0] * w1[3, 0]) + (-mE[1] * w1[3, 1])) * s(-h0_0_out) * (1 - s(-h0_0_out)));
                    w0[3, 1] = w0[3, 1] - (((-mE[0] * w1[3, 0]) + (-mE[1] * w1[3, 1])) * s(-h0_1_out) * (1 - s(-h0_1_out)));
                    w0[3, 2] = w0[3, 2] - (((-mE[0] * w1[3, 0]) + (-mE[1] * w1[3, 1])) * s(-h0_2_out) * (1 - s(-h0_2_out)));


                    w1[0, 0] = w1[0, 0] - (-mE[0] * h0_0_out);
                    w1[1, 0] = w1[1, 0] - (-mE[0] * h0_1_out);
                    w1[2, 0] = w1[2, 0] - (-mE[0] * h0_2_out);
                    w1[3, 0] = w1[3, 0] - (-mE[0]);

                    w1[0, 1] = w1[0, 1] - (-mE[1] * h0_0_out);
                    w1[1, 1] = w1[1, 1] - (-mE[1] * h0_1_out);
                    w1[2, 1] = w1[2, 1] - (-mE[1] * h0_2_out);
                    w1[3, 1] = w1[3, 1] - (-mE[1]);
                }
            }

            Invoke(new dldl(delegate
            {
                button1.Enabled = true;
            }));
        }

        private double s(double x)
        {
            return 1.0/(1.0 + Math.Exp(-x));
        }

        private delegate void dldl();

        private void dataGridView1_CellPainting(object sender, DataGridViewCellPaintingEventArgs e)
        {
            //列ヘッダーかどうか調べる
            if (e.ColumnIndex < 0 && e.RowIndex >= 0)
            {
                //セルを描画する
                e.Paint(e.ClipBounds, DataGridViewPaintParts.All);

                //行番号を描画する範囲を決定する
                //e.AdvancedBorderStyleやe.CellStyle.Paddingは無視しています
                Rectangle indexRect = e.CellBounds;
                indexRect.Inflate(-2, -2);
                //行番号を描画する
                TextRenderer.DrawText(e.Graphics,
                    (e.RowIndex + 1).ToString(),
                    e.CellStyle.Font,
                    indexRect,
                    e.CellStyle.ForeColor,
                    TextFormatFlags.Right | TextFormatFlags.VerticalCenter);
                //描画が完了したことを知らせる
                e.Handled = true;
            }
        }

        private void button1_Click(object sender, EventArgs e)
        {
            button1.Enabled = false;
            Thread th = new Thread(nnw);
            th.Start();
        }
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0