LoginSignup
0
0

More than 3 years have passed since last update.

機械学習で四則演算 4【積 失敗2 学習率導入】

Last updated at Posted at 2018-01-29

積2.png

前回はあまりな事に思わず取り乱してしまいました。\\
落ち着いてみたら、学習率というものがあったのを思い出しました。\\
なので\\
w[0]=
\begin{pmatrix}
1 & 1\\
1 & -1
\end{pmatrix}
で固定\\
w[1]の学習率を η=0.5 とし\\
w[1]=
\begin{pmatrix}
0.25\\
-0.25
\end{pmatrix}
になるか1000回学習させてみました。
        string path = @"D:\開発\AI\" + DateTime.Now.ToString("yyyyMMdd") + "積.csv";
        double[][,] w = new double[2][,];

        int In = 2;
        //int layer = 1;
        int cell = 2;
        //int bias = 0;
        int Out = 1;
        double η = 0.5;

        Stopwatch sw = new Stopwatch();
        delegate void dldl();

        private void button1_Click(object sender, EventArgs e)
        {
            button1.Enabled = false;
            sw.Start();
            Thread th = new Thread(perceptron);
            th.Start();
        }

        private void perceptron()
        {
            //列見出し
            string row_title = ",a,b,ab,";
            for (int n1 = 0; n1 < In; n1++)
            {
                for (int n2 = 0; n2 < cell; n2++)
                {
                    row_title += "w[0][" + n1 + "_" + n2 + "],";
                }
            }
            for (int n = 0; n < cell; n++)
            {
                row_title += "h_in[" + n + "],h_out[" + n + "],";
            }
            for (int n1 = 0; n1 < cell; n1++)
            {
                for (int n2 = 0; n2 < Out; n2++)
                {
                    row_title += "w[1][" + n1 + "_" + n2 + "],";
                }
            }
            row_title += "Y,ΔE";
            using (StreamWriter sw = new StreamWriter(path, true, Encoding.Unicode))
            {
                sw.Write(row_title + Environment.NewLine);
            }

            //重み初期設定
            //w[0] = new double[In, cell];
            w[0] = new double[2, 2] { { 1, 1 }, { 1, -1 } };            
            w[1] = new double[cell, Out];
            //w[1] = new double[2, 1] { { 0.25 }, { -0.25 } };

            Random cRandom = new Random();// 0.0 以上 1.0 以下の乱数を取得
            //RNGCryptoServiceProvider rng = new RNGCryptoServiceProvider();
            //byte[] bs = new byte[sizeof(int)];
            //for (int n0 = 0; n0 < w.Length; n0++)
            //{
                for (int n1 = 0; n1 < w[1].GetLength(0); n1++)
                {
                    for (int n2 = 0; n2 < w[1].GetLength(1); n2++)
                    {
                        w[1][n1, n2] = cRandom.NextDouble();
                        //rng.GetBytes(bs);
                        //w[n0][n1, n2] = BitConverter.ToInt32(bs, 0);
                    }
                }
            //}

            //学習
            for (int p = 1; p <= 2000; p++)
            {
                //入力値設定
                double a = cRandom.NextDouble();
                double b = cRandom.NextDouble();
                //rng.GetBytes(bs);
                //double a = BitConverter.ToInt32(bs, 0);
                //rng.GetBytes(bs);
                //double b = BitConverter.ToInt32(bs, 0);

                //教師信号
                double a_mul_b = a * b;

                //順伝播
                double[] h_in = new double[cell];
                double[] h_out = new double[cell];

                for (int n = 0; n < cell; n++)
                {//入力
                    h_in[n] = a * w[0][0, n] + b * w[0][1, n];
                    h_out[n] = h_in[n] * h_in[n];
                }

                double Y = 0;
                for (int cl = 0; cl < cell; cl++)
                {//出力
                    Y += h_out[cl] * w[1][cl, 0];
                }

                //二乗誤差
                double dE = Y - a_mul_b;//計算省略のため二乗誤差微分後の値

                //記録
                string rec = p.ToString() + "," + a.ToString() + "," + b.ToString() + "," + a_mul_b.ToString() + ",";
                for (int n1 = 0; n1 < w[0].GetLength(0); n1++)
                {
                    for (int n2 = 0; n2 < w[0].GetLength(1); n2++)
                    {
                        rec += w[0][n1, n2].ToString() + ",";
                    }
                }
                for (int n = 0; n < cell; n++)
                {
                    rec += h_in[n].ToString() + "," + h_out[n].ToString() + ",";
                }
                for (int n1 = 0; n1 < w[1].GetLength(0); n1++)
                {
                    for (int n2 = 0; n2 < w[1].GetLength(1); n2++)
                    {
                        rec += w[1][n1, n2].ToString() + ",";
                    }
                }
                rec += Y.ToString() + "," + dE.ToString();
                using (StreamWriter sw = new StreamWriter(path, true, Encoding.Unicode))
                {
                    sw.Write(rec + Environment.NewLine);
                }

                //逆伝播
                //∂E/∂In                       
                //double[] dE_dI = new double[cell];
                //for (int cl = 0; cl < cell; cl++)
                //{
                //    dE_dI[cl] = 2 * h_in[cl] * w[1][cl, 0] * dE;
                //}
                //w-Δw
                //for (int cl = 0; cl < cell; cl++)
                //{
                //    w[0][0, cl] = w[0][0, cl] - η * (a * dE_dI[cl]);
                //    w[0][1, cl] = w[0][1, cl] - η * (b * dE_dI[cl]);
                //}
                for (int cl = 0; cl < cell; cl++)
                {
                    w[1][cl, 0] = w[1][cl, 0] - η * (h_out[cl] * dE);
                }
            }

            //最終記録
            string last_rec = new string(',', In + Out + 1);
            for (int n1 = 0; n1 < w[0].GetLength(0); n1++)
            {
                for (int n2 = 0; n2 < w[0].GetLength(1); n2++)
                {
                    last_rec += w[0][n1, n2].ToString() + ",";
                }
            }
            last_rec += new string(',', cell * 2);
            for (int n1 = 0; n1 < w[1].GetLength(0); n1++)
            {
                for (int n2 = 0; n2 < w[1].GetLength(1); n2++)
                {
                    last_rec += w[1][n1, n2].ToString() + ",";
                }
            }
            last_rec += new string(',', Out * 2 - 1);
            using (StreamWriter sw = new StreamWriter(path, true, Encoding.Unicode))
            {
                sw.Write(last_rec);
            }

            //読取専用
            File.SetAttributes(path, FileAttributes.ReadOnly);

            Invoke(new dldl(delegate
            {
                sw.Stop();
                label1.Text = sw.Elapsed.ToString();
                button1.Enabled = true;
            }));

        }        

積w[1]η=0.5.png

目標値\\
\begin{pmatrix}
0.25\\
-0.25
\end{pmatrix}\\
初期値\\
\begin{pmatrix}
0.503390885192617\\
0.924137140588899
\end{pmatrix}\\
計算値\\
\begin{pmatrix}
0.25\\
-0.25
\end{pmatrix}

意外にもあっさりうまくいきました。
私の狭い知見で本やネットを調べた限り、この学習率というのは試行錯誤して適切な値に設定する必要があるようなので
η=0.8 でもやってみました。

積w[1]η=0.8.png

目標値\\
\begin{pmatrix}
0.25\\
-0.25
\end{pmatrix}\\
初期値\\
\begin{pmatrix}
0.527055652591891\\
0.315029976104866
\end{pmatrix}\\
計算値\\
\begin{pmatrix}
589595066397.405\\
3767069263729.84
\end{pmatrix}

η=0.8 ではかんばしくなかったので
η=0.1 にしてみました。

積w[1]η=0.1.png

目標値\\
\begin{pmatrix}
0.25\\
-0.25
\end{pmatrix}\\
初期値\\
\begin{pmatrix}
0.131565617924354\\
0.974423807102453
\end{pmatrix}\\
計算値\\
\begin{pmatrix}
0.249653737025605\\
-0.245864964583958
\end{pmatrix}

目標値に収束しそうなのですが学習率が低すぎて定着が遅いようです。
他人事とは思えません。
2000回学習させてみました。

積2000w[0]固定w[1]η=0.1.png

目標値\\
\begin{pmatrix}
0.25\\
-0.25
\end{pmatrix}\\
初期値\\
\begin{pmatrix}
0.95706960137797\\
0.942992231782056
\end{pmatrix}\\
計算値\\
\begin{pmatrix}
0.249996098305968\\
-0.249943942878657
\end{pmatrix}

「良くがんばったね」
思わず声をかけてしまいました。

学習率をもっと小刻みに、0.1刻みぐらいで全部試そうかと思いましたが飽きてきたので
w[0]の学習率η[0]=1.0
w[1]の学習率η[1]=0.5
で学習させたらどうなるかを試してみました。

ちなみに、この段階で初期値や学習させる値を統一しなきゃいけないところを
毎回乱数を発生させていたのに気がつきました。
とは言うものの、保存した値を読み出すプログラムを書くのが面倒なので手を抜きました。
こういう所で、ちゃんとする人と、しない人で人生が変わってくるのかもしれません。

積1000w[0]η[0]=1.0η[1]=0.5.png
積1000w[1]η[0]=1.0η[1]=0.5.png

目標値\\
w[0]=
\begin{pmatrix}
1 & 1\\
1 & -1
\end{pmatrix}
w[1]=
\begin{pmatrix}
0.25\\
-0.25
\end{pmatrix}\\
初期値\\
w[0]=
\begin{pmatrix}
0.365073265677818 & 0.624229858920085\\
0.725459170399913 & 0.796839401496965
\end{pmatrix}
w[1]=
\begin{pmatrix}
0.805179213082967\\
0.050726107810962
\end{pmatrix}\\
計算値\\
w[0]=
\begin{pmatrix}
0.569994619242016 & 0.580250012379678\\
0.531595328837389 & 0.546938464664657
\end{pmatrix}
w[1]=
\begin{pmatrix}
0.393130085134921\\
0.378365301427354
\end{pmatrix}

収束しそうでジタバタしてます。

以降、試行錯誤をダイジェストでご覧下さい。
積2000w[0]η[0]=1.0η[1]=0.5.png
積2000w[1]η[0]=1.0η[1]=0.5.png
積5000w[0]η[0]=1.0η[1]=0.5.png
積5000w[1]η[0]=1.0η[1]=0.5.png
積2000w[0]η[0]=0.5η[1]=0.5.png
積2000w[1]η[0]=0.5η[1]=0.5.png
積5000w[0]η[0]=0.5η[1]=0.5.png
積5000w[1]η[0]=0.5η[1]=0.5.png

次回
積 Python導入

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0