LoginSignup
0
0

More than 1 year has passed since last update.

コンピュータとオセロ対戦54 ~勾配降下法と誤差逆伝播法~

Last updated at Posted at 2022-05-03

前回

今回の目標

前回作成したネットワークのパラメータ更新を実装する。

ここから本編

パラメータ更新方法

誤差逆伝播法によるパラメータの更新方法は44 ~勾配降下法と誤差逆伝播法~でもまとめましたが、難しい内容なのでもう一度まとめます。
また上記事の、

下付き文字ですが、「pre」が前の層、「next」が次の層、「now」もしくは何も書いていなければ現在着目している層、「last」が最終層(出力層)を指します。下付き文字がコンマで区切られて二つある場合、二つ目はその層の何番目のノードのものであるかを指します。
また、「old」が更新前のパラメータ、「new」が更新後のパラメータです。
一般的に他の文献では層の数が上付き文字であらわされることが多いですが、累乗と区別がつかずわかりにくいのでここでは用いません。

というルールはここでも適用することにします。

全結合層

勾配降下法

最急降下法とも呼ばれる計算方法。誤差を$E$、重みを$w$、学習率を$\eta$とおくと

w_{new}=w_{old}-\eta \frac{\partial E}{\partial w_{old}}

と表せられます。これが重みの更新式であり、誤差逆伝播法とか畳み込みニューラルネットワークのフィルタ更新とかのすべての基本になります。
ここで、$w_{old}$は今まで使用していた重みですから当然既知です。$\eta$も自分で設定します。0.01ぐらいの値が使われることが多いです。
つまり勾配降下法の式で未知数となるのは右辺第二項の分数部分です。
ここで、連鎖律を用いると

\frac{\partial E}{\partial w_{old,n}}=\frac{\partial E}{\partial x_{now,n}}\frac{\partial x_{now,n}}{\partial w_{old,n}}

と表せられます。ここで$x_{now,n}$は、現在着目している層の、ノード番号$n$のノードでの線形変換後の値です。そのため、前の層の出力を$a_{pre}$とすると、

\begin{aligned}
x_{now,n}&=a_{pre}w_{old,n}\\
a_{pre}&=\frac{x_{now,n}}{w_{old,n}}\\
&=\frac{\partial x_{now,n}}{\partial w_{old,n}}
\end{aligned}

がいえます(実際は行列計算になります)。今着目しているノードの重みに、前の層の出力をかけたものが今着目しているノードの線形変換後の値になりますから上の二行は成り立ちます。また、これは線形変換ですから、$w$で割るのも微分するのも同じになります。
これを勾配降下法の式に代入します。

\begin{aligned}
w_{new,n}&=w_{old,n}-\eta \frac{\partial E}{\partial w_{old,n}}\\
&=w_{old,n}-\eta \frac{\partial E}{\partial x_{now,n}}a_{pre}
\end{aligned}

これで未知数は$\frac{\partial E}{\partial x_{now,n}}$だけになりました。これを$\delta$とおくと、

$$
w_{new,n}=w_{old,n}-\eta \delta_n a_{pre}
$$

式が簡単になります。あとは$\delta$さえ求められれば、重みが更新できるようになります。
(重みはノードごとに違いますから、$\delta$もノードごとに違う値をとります)

最終層のデルタ

$\delta$の求め方は、最終層とそれ以外で異なります。まず、最終層での求め方について記します。
まず、$\delta$のもともとの定義は、ノード番号を$n$をすると

\begin{aligned}
\delta_n&=\frac{\partial E}{\partial x_{now,n}}\\
&=\frac{\partial E}{\partial x_{last,n}}
\end{aligned}

でした。連鎖律から、

$$
\delta_n=\frac{\partial E}{\partial a_{last,n}}\frac{\partial a_{last,n}}{\partial x_{last,n}}
$$

がいえます。ここで、右辺一つ目の分数を見ると、

\begin{aligned}
\frac{\partial E}{\partial a_{last,n}}&=\frac{\partial E(a_{last,n})}{\partial a_{last,n}}\\ \nonumber
&=E'(a_{last,n})
\end{aligned}

となります。誤差関数の微分$E'(a)$は、たとえば平均二乗誤差なら、正解を$t$、データ数を$N$とおくと

\begin{aligned}
E(a)&=\frac{(a-t)^2}{N}\\
E'(a)&=\frac{2(a-t)}{N}
\end{aligned}

のようになります。

また、連鎖律の式の右辺二つ目の分数は、活性化関数を$f$、最終層のノード数を$m$とすると

\begin{aligned}
\frac{\partial a_{last,n}}{\partial x_{last,n}}&=\frac{\partial f(x_{last})}{\partial x_{last,n}}\\
&=f'(x_{last,n})
\end{aligned}

となります。この層の出力$a_{last}$はこの層での線形変換後の値$x_{last}$に活性化関数をかけたものですから、当然こうなります。活性化関数の微分$f'$は、たとえばReLu関数なら

\begin{align}
% f(x)&=\left\{\begin{array}{l, l}
% 0&if\ x < 0\\
% x&if\ x >= x
% \end{array}\right.\\
% f'(x)&=\left\{\begin{array}{l,l}
% 0&if\ x < 0\\
% 1&if\ x >= 0
% \end{array}\right.

f(x)&=\left\{\begin{aligned}
    0&\ if\ x < 0\\
    x&\ if\ x >= x
\end{aligned}\right. \nonumber \\
f'(x)&=\left\{\begin{aligned}
    0&\ if\ x < 0\\
    1&\ if\ x >= x
\end{aligned}\right. \nonumber
\end{align}

となります。
右辺の一つ目・二つ目の分数をまとめると、最終層では、

\begin{aligned}
\delta_n&=\frac{\partial E}{\partial a_{last,n}}\frac{\partial a_{last,n}}{\partial x_{last,n}}\\
&=E'(a_{last,n})f'(x_{last,n})
\end{aligned}

となります。これを勾配降下法の式に代入すると、

\begin{aligned}
w_{new}&=w_{old}-\eta \delta_n a_{pre}\\
&=w_{old}-\eta E'(a_{last,n})f'(x_{last,n})a_{pre}\\
\end{aligned}

となります。全て既知ですので、これで重みが更新できるようになります。

では、具体的に値を当てはめてみましょう。
今、前の層のノード数が3、最終層のノード数が2でデータ数が5とします。
入力行列(前の層の出力)、重み行列、線形変換後の行列は以下のようになります。

\begin{align}
\mathbf A&=\left[\begin{array}{ccc}
a_{1,1}&a_{1,2}&a_{1,3}\\
a_{2,1}&a_{2,2}&a_{2,3}\\
a_{3,1}&a_{3,2}&a_{3,3}\\
a_{4,1}&a_{4,2}&a_{4,3}\\
a_{5,1}&a_{5,2}&a_{5,3}
\end{array}\right]\nonumber \\
\mathbf W&=\left[\begin{array}{cc}
w_{1,1}&w_{2,1}\\
w_{1,2}&w_{2,2}\\
w_{1,3}&w_{2,3}\\
b_1&b_2
\end{array}\right]\nonumber \\
\mathbf X&=\left[\begin{array}{cc}
a_{1,1}w_{1,1}+a_{1,2}w_{1,2}+a_{1,3}w_{1,3}+b_1&a_{1,1}w_{2,1}+a_{1,2}w_{2,2}+a_{1,3}w_{2,3}+b_2\\
a_{2,1}w_{1,1}+a_{2,2}w_{1,2}+a_{2,3}w_{1,3}+b_1&a_{2,1}w_{2,1}+a_{2,2}w_{2,2}+a_{2,3}w_{2,3}+b_2\\
a_{3,1}w_{1,1}+a_{3,2}w_{1,2}+a_{3,3}w_{1,3}+b_1&a_{3,1}w_{2,1}+a_{3,2}w_{2,2}+a_{3,3}w_{2,3}+b_2\\
a_{4,1}w_{1,1}+a_{4,2}w_{1,2}+a_{4,3}w_{1,3}+b_1&a_{4,1}w_{2,1}+a_{4,2}w_{2,2}+a_{4,3}w_{2,3}+b_2\\
a_{5,1}w_{1,1}+a_{5,2}w_{1,2}+a_{5,3}w_{1,3}+b_1&a_{5,1}w_{2,1}+a_{5,2}w_{2,2}+a_{5,3}w_{2,3}+b_2
\end{array}\right]\nonumber
\end{align}

ここで、$m$行$n$列の行列を内容問わず$M_{m,n}$と表すことにします。
すると、上の行列は以下のように表現できます。

\begin{align}
\mathbf A &=M_{5,3} \nonumber \\
\mathbf W &=M_{4,2} \nonumber \\
\mathbf X &=M_{5,2} \nonumber
\end{align}

最終章の$\delta$は、$\delta_n = E'(a_{last,n})f'(x_{last,n})$でした。
これを計算すると、

\begin{align}
\delta_n&=M_{5,1}M_{5,1}\nonumber \\
&=M_{1,5}M_{5,1}\ \ \ \ (転置しています)\nonumber \\
&=M_{1,1}\nonumber
\end{align}

となり、スカラになります。
ここで、最終層が持つすべての$\delta_n$を行列としてまとめて$\boldsymbol \delta$とすると、

\begin{align}
\boldsymbol \delta &=M_{1,2}\nonumber
\end{align}

になります。
$\boldsymbol\delta$はデータ数行・ノード数列の行列になることが分かりました。
次に、重み更新の式は$w_{new}=w_{old,n}-\eta\delta_{n}a_{pre}$でした。
計算しやすさのため前の層の出力にバイアスを加え、また、平均します。これで前の層の出力は$M_{1,4}$となりますから、

\begin{align}
\mathbf W_{new} &= \mathbf W_{old}-\eta\boldsymbol\delta\mathbf A \nonumber \\
&= M_{4,2}-\eta M_{1,2}M_{1,4} \nonumber \\
&= M_{4,2}-\eta M_{2,1}M_{1,4}\ \ \ \ (転置しています) \nonumber \\
&= M_{4,2}-\eta M_{2,4} \nonumber \\
&= M_{4,2} \nonumber
\end{align}

こう計算できます。
重み更新ですから行と列の数がそれぞれ上のようになるのは納得できますね。

最終層以外のデルタ

次に、最終層以外を見ていきます。$\delta$の定義より、連鎖律を用いると

\begin{aligned}
\delta_n&=\frac{\partial E}{\partial x_{now,n}}\\
&=\frac{\partial E}{\partial x_{next}}\frac{\partial x_{next}}{\partial x_{now,n}}\\
&=\sum\limits_{n=1}^m \left\{\frac{\partial E}{\partial x_{next,n}}\frac{\partial x_{next,n}}{\partial x_{now,n}}\right\}\\
&=\sum\limits_{n=1}^m \left\{\frac{\partial E}{\partial x_{next,n}}\frac{\partial x_{next,n}}{\partial a_{now}}\frac{\partial a_{now}}{\partial x_{now,n}}\right\}\\
&=\sum\limits_{n=1}^m \left\{\frac{\partial E}{\partial x_{next,n}}\frac{\partial x_{next,n}}{\partial a_{now}}\right\}\frac{\partial a_{now}}{\partial x_{now,n}}
\end{aligned}

がいえます。この右辺の分数を一つずつ見ていきます。
まず一つ目。

\begin{aligned}
\frac{\partial E}{\partial x_{next,n}}=\delta_{next,n}
\end{aligned}

これは$\delta$の定義から言えますね。次の層の$\delta$は、誤差逆伝播法により既に求めているはずですから未知数ではありません。
次に二つ目。

\begin{aligned}
\frac{\partial x_{next,n}}{\partial a_{now}}&=w_{next,n}
\end{aligned}

少しわかりづらいですが、$a_{now}w_{next,n}=x_{next,n}$と書けば納得できるでしょうか。線形変換ですから、割り算も微分も同義です。
最後に三つ目。

\begin{aligned}
\frac{\partial a_{now}}{\partial x_{now,n}}=f'(x_{now,n})
\end{aligned}

これは最終層での計算でも登場しましたので説明は省きます。
そして、これらの計算により、

\begin{aligned}
\delta_n&=\sum\limits_{n=1}^m \left\{\frac{\partial E}{\partial x_{next,n}}\frac{\partial x_{next,n}}{\partial a_{now}}\right\}\frac{\partial a_{now}}{\partial x_{now,n}}\\
&=\sum\limits_{n=1}^m \left\{\delta_{next,n}w_{next,n}\right\}f'(x_{now,n})
\end{aligned}

となります。これで未知数はすべてなくなりましたから、勾配降下法の式に代入すると、

\begin{aligned}
w_{new,n}&=w_{old,n}-\eta \delta_n a_{pre}\\
&=w_{old,n}-\eta \sum\limits_{n=1}^m \left\{\delta_{next,n}w_{next,n}\right\}f'(x_{now,n}) a_{pre}
\end{aligned}

全結合層まとめ

最終層

\begin{aligned}
\delta_n&=E'(a_{last})f'(x_{last,n})\\
w_{new}&=w_{old}-\eta E'(a_{last})f'(x_{last,n})a_{pre}
\end{aligned}

最終層以外

\begin{aligned}
\delta_n&=\sum\limits_{n=1}^m \left\{\delta_{next,n}w_{next,n}\right\}f'(x_{now,n})\\
w_{new}&=w_{old,n}-\eta \sum\limits_{n=1}^m \left\{\delta_{next,n}w_{next,n}\right\}f'(x_{now,n}) a_{pre}
\end{aligned}

畳み込み層

畳み込み層については、間違っているところがあるかもしれません。

畳み込み層が最終層になることはないでしょうから、後ろに何かしらの層(別の畳み込み層層またはプーリング層)があるという前提で考えます。
ここで、入力される行列の大きさをCxIxJとします(RGBのフルHDなら、3x1920x1080)。
また、重み行列のサイズをPxQ、カーネル数をKとし、k番目のcチャネルに対応するp行q列の重みを$w_{k,c,p,q}$とおきます。ストライドを$S$とおきます。

\begin{aligned}
w_{new,n}&=w_{old,n}-\eta \frac{\partial E}{\partial w_{old,n}}\\
&=w_{old,n}-\eta \delta_n a_{pre}
\end{aligned}

上に示した勾配降下法の式が、全結合層と同様に基本になります。
ここでも、$\delta$を求めることが主目的になります。

畳み込み層用に書き直すと、以下のようになります。

\begin{aligned}
w_{new,k,c,p,q}&=w_{old,k,c,p,q}-\eta \frac{\partial E}{\partial w_{old,k,c,p,q}}\\
&=w_{old,k,c,p,q}-\eta \delta_{c,i,j} a_{pre}
\end{aligned}

バイアスについては以下の通り。

\begin{aligned}
b_{new,k}&=b_{old,k}-\eta \frac{\partial E}{\partial b_{old,k}}\\
&=w_{old,k}-\eta \delta_{c,i,j} a_{pre}
\end{aligned}

ここで、右辺第二項の分数部分を見てみると以下のように表せられます。

\begin{align}
\frac{\partial E}{\partial w_{old,k,c,p,q}}&=\sum\limits_{i=1}^I\sum\limits_{j=1}^J \frac{\partial E}{\partial a_{now,k,i,j}}\frac{\partial a_{now,k,i,j}}{\partial w_{old,k,c,p,q}}\nonumber
\end{align}

ここで、

\begin{align}
\delta_{now}&=\frac{\partial E}{\partial x_{now}}\nonumber \\
&=\frac{\partial E}{\partial a_{now}}\frac{\partial a_{now}}{\partial x_{now}}\nonumber \\
\frac{\partial E}{\partial a_{now}}&=\delta\frac{\partial x_{now}}{\partial a_{now}}\nonumber
\end{align}

ですから、上式は

\begin{align}
\frac{\partial E}{\partial w_{old,k,c,p,q}}&=\sum\limits_{i=1}^I\sum\limits_{j=1}^J \frac{\partial E}{\partial a_{now,k,i,j}}\frac{\partial a_{now,k,i,j}}{\partial w_{old,k,c,p,q}}\nonumber\\
&=\sum\limits_{i=1}^I\sum\limits_{j=1}^J \delta_{now,k,i,j}\frac{\partial x_{k,i,j}}{\partial a_{now,k,i,j}}\frac{\partial a_{now,k,i,j}}{\partial w_{old,k,c,p,q}}\nonumber\\
&=\sum\limits_{i=1}^I\sum\limits_{j=1}^J \delta_{now,k,i,j}\frac{\partial x_{k,i,j}}{\partial w_{old,k,c,p,q}}\nonumber\\
&=\sum\limits_{i=1}^I\sum\limits_{j=1}^J\delta_{now,k,i,j}a_{pre,k,Si+p,Sj+q}\nonumber
\end{align}

と変形できます。
最後の変換は、全結合層で説明したものと同じ理屈です。$a_{pre}w+b=x$ですから、当然$\frac{\partial x}{\partial w}=a_{pre}$になります。
バイアスについては以下のようになります。

\begin{align}
\frac{\partial E}{\partial b_{old,k}}&=\sum\limits_{i=1}^I\sum\limits_{j=1}^J \frac{\partial E}{\partial a_{now,k,i,j}}\frac{\partial a_{now,k,i,j}}{\partial b_{old,k}}\nonumber\\
&=\sum\limits_{i=1}^I\sum\limits_{j=1}^J\delta_{k,i,j}\nonumber
\end{align}

$\frac{\partial x}{\partial b}=1$ですから、上の式は納得できるはずです。

ここで、$\delta$は以下のように表せられます。

\begin{align}
\delta_{now,k,i,j}&=\frac{\partial E}{\partial x_{now,k,p,q}}\nonumber\\
&=\sum\limits_{k=1}^K\sum\limits_{p=1}^P\sum\limits_{q=1}^Q\frac{\partial E}{\partial a_{next,k,i-p,j-q}}\frac{\partial a_{next,k,i-p,j-q}}{\partial a_{now,k,i,j}}\nonumber\\
\end{align}

ここで、$\frac{\partial E}{\partial a_{now}}=\delta\frac{\partial x_{now}}{\partial a_{now}}$より、

\begin{align}
\delta_{now,k,i,j}&=\sum\limits_{k=1}^K\sum\limits_{p=1}^P\sum\limits_{q=1}^Q\delta_{next,k,i,j}\frac{\partial x_{next,k,i-p,j-q}}{\partial a_{next,k,i-p,j-q}}\frac{\partial a_{next,k,i-p,j-q}}{\partial a_{now,k,i,j}}\nonumber
\end{align}

がいえます。また、

\begin{align}
a_{next,k,i-p,j-q}&=f(a_{now,k,i,j}w_{next,c,k,p,q}+b_{next,k})\nonumber\\
\frac{\partial a_{next,k,i-p,j-q}}{\partial a_{now,k,i,j}}&=w_{next,c,k,p,q}f'(a_{now,k,i,j}w_{next,c,k,p,q}+b_{next,k})\nonumber
\end{align}

となりますから、

\begin{align}
\delta_{now,k,i,j}&=\sum\limits_{k=1}^K\sum\limits_{p=1}^P\sum\limits_{q=1}^Q\delta_{next,k,i,j}\frac{\partial x_{next,k,i-p,j-q}}{\partial a_{next,k,i-p,j-q}}\frac{\partial a_{next,k,i-p,j-q}}{\partial a_{now,k,i,j}}\nonumber\\
&=\sum\limits_{k=1}^K\sum\limits_{p=1}^P\sum\limits_{q=1}^Q\delta_{next,k,i,j}w_{next,c,k,p,q}f'(a_{now,k,i,j}w_{next,c,k,p,q}+b_{next,k})\nonumber
\end{align}

と計算できます。

畳み込み層まとめ

$\frac{\partial E}{\partial b}$を$\Delta b$、$\frac{\partial E}{\partial w}$を$\Delta w$と表しています。

バイアス

\begin{align}
b_{new,k}&=b_{old,k}-\eta\Delta b_k \nonumber \\
\Delta b_k&=\sum\limits_{i=1}^{\lfloor (I-P+1)/S \rfloor}\sum\limits_{j=1}^{\lfloor (J-Q+1)/S \rfloor}\delta_{next,k,i,j}f'(x_{now,k,i,j}+b_k) \nonumber \\
\end{align}

重み

\begin{align}
w_{new,c,p,q}&=w_{old,c,p,q}-\eta\Delta w_{old,c,p,q} \nonumber \\
\Delta w_{old,c,p,q}&=\sum\limits_{c=1}^C\sum\limits_{p=1}^P\sum\limits_{q=1}^Q\left[\sum\limits_{i=1}^{\lfloor (I-P+1)/S \rfloor}\sum\limits_{j=1}^{\lfloor (J-Q+1)/S \rfloor}\left\{\delta_{next,c,i,j}f'(x_{now,c,i,j}+b)\right\}a_{pre,c,i+p,j+q}\right]\nonumber
\end{align}

式が複雑になるので省略していますが、上の計算をカーネル数繰り返します。
また、$\delta$は以下の式で表せられます。

\begin{align}
\delta_{c,i,j}&=\sum\limits_{k=1}^K\sum\limits_{p=1}^P\sum\limits_{q=1}^Q\left\{\delta_{next,k,i-P-p+1,j-Q-q+1}f'(x_{k,i-P-p+1,j-Q-q+1}+b_k)w_{c,k,p,q}\right\} \nonumber
\end{align}

プーリング層

プーリング層には更新するようなパラメータはありません。

更新プログラム

長らくお待たせいたしました。これから、勾配降下法を誤差逆伝播法で実装していきます。
まず基本方針として、MyNet2はMyNetと異なり、全結合層のみで構成されているわけではありません。そのため、全ての層で同じ手法を使うことができません。なので、各層専用のbackメソッドを用意し、層の種類に合わせてそれらを呼び出す方針にしました。
また、全ての最適化関数クラスはOptimizerクラスを継承する基本方針は同じですが、勾配降下法以外の最適化関数クラスは確率的勾配降下法クラスを継承するように変更しました。理由は、勾配降下法以外の最適化クラスでは同じfitメソッドを使うためです。

まず全ての最適化関数クラスの親クラスであるOptimzierクラスを作成しました。今回は使用しませんが身にバッチ作成メソッドを作成しました。

Optimizer.java
package org.MyNet2.optimizer;

import java.util.ArrayList;
import java.util.Random;

import org.MyNet2.network.*;
import org.MyNet2.lossFunc.*;
import org.MyNet2.layer.*;
import org.MyNet2.*;

/**
 * Class for optimizer.
 */
public class Optimizer {
    public Random rand;
    /** Optimizing network */
    public Network net;
    /** Loss function of this network. */
    public LossFunction lossFunc;
    /** Length of this network. */
    public int layersLength;
    /** Learning rate. */
    public double eta = 0.01;

    /**
     * Constructor for this class.
     */
    protected Optimizer(){
        ;
    }

    /**
     * Make data for mini batch learning.
     * @param x Input data.
     * @param t Answer.
     * @param batchSize Number of batch size.
     * @param rand Random instance.
     * @return Splited input data and answer.
     */
    protected Matrix[][] makeMiniBatch(Matrix x, Matrix t, int batchSize, Random rand){
        int rtnSize = (int)(x.row / batchSize) + 1;
        int num;
        int i;
        ArrayList<Integer> order = new ArrayList<Integer>(rtnSize);
        ArrayList<Integer> check = new ArrayList<Integer>(rtnSize);

        for (i = 0; i < x.row; i++){
            check.add(i);
        }
        for (i = 0; i < x.row; i++){
            num = rand.nextInt(x.row - order.size());
            order.add(check.get(num));
            check.remove(num);
        }
        for (; i < rtnSize*batchSize; i++){
            order.add(rand.nextInt(x.row));
        }

        Matrix x_ = x.vsort(order);
        Matrix t_ = t.vsort(order);
        return new Matrix[][]{x_.vsplit(rtnSize), t_.vsplit(rtnSize)};
    }

    /**
     * Doing back propagation.
     * @param x input matrix.
     * @param y Result of forward propagation.
     * @param t Answer.
     */
    protected void back(Matrix x, Matrix y, Matrix t){
        ;
    }
}

全結合層

どんなネットワークであれ最終層は全結合層になるため、まず最終層専用の更新メソッドで最終層の重みを更新し、その後他の層の重みを更新していきます。

コンストラクタなど

GD.java
package org.MyNet2.optimizer;

import java.util.ArrayList;
import java.util.Random;
import java.io.PrintWriter;
import java.io.IOException;

import org.MyNet2.network.*;
import org.MyNet2.lossFunc.*;
import org.MyNet2.layer.*;
import org.MyNet2.*;

/**
 * Class for gradient descent.
 */
public class GD extends Optimizer {
    /**
     * Constructor fot this class.
     * @param net Optimizing network.
     * @param f Loss function of this network.
     */
    public GD(Network net, LossFunction f){
        this.net = net;
        this.lossFunc = f;
        this.layersLength = net.layers.length;
    }

    /**
     * Constructor fot this class.
     * @param net Optimizing network.
     * @param f Loss function of this network.
     * @param eta Learning rate.
     */
    public GD(Network net, LossFunction f, double eta){
        this.net = net;
        this.lossFunc = f;
        this.layersLength = net.layers.length;
        this.eta = eta;
    }

誤差逆伝播メソッド

まず最終層の重みを更新し、次に各層の重みを層の種類ごとに手法を変えながら更新します。

GD.java
    /**
     * Doing back propagation.
     * @param x input matrix.
     * @param y Result of forward propagation.
     * @param t Answer.
     */
    protected void back(Matrix x, Matrix y, Matrix t){
        this.backLastLayer(x, y, t);

        for (int i = this.layersLength-2; i >= 0; i--){
            switch (this.net.layers[i].name){
                case "Dense":
                    this.backDense(
                        i,
                        i == 0 ? x.meanCol().appendCol(1.) : this.net.layers[i-1].a.meanCol().appendCol(1.)
                    );
                    break;
                case "Conv":
                    this.backConv(
                        i,
                        i == 0 ? x : this.net.layers[i-1].a.meanCol()
                    );
                    break;
                case "MaxPooling":
                    this.backMaxPooling(i);
                    break;
                default:
                    ;
            }
        }
    }

重み更新メソッド

最終層

式はこちら。

\begin{aligned}
\delta_n&=E'(a_{last})f'(x_{last,n})\\
w_{new}&=w_{old}-\eta E'(a_{last})f'(x_{last,n})a_{pre}
\end{aligned}
GD.java
    /**
     * Doing back propagation to last layer.
     * @param x input matrix.
     * @param y Result of forward propagation.
     * @param t Answer.
     */
    protected void backLastLayer(Matrix x, Matrix y, Matrix t){
        Layer lastLayer = this.net.layers[this.layersLength-1];
        Layer preLayer = this.net.layers[this.layersLength-2];

        // E'(a)
        Matrix E = this.lossFunc.diff(lastLayer.a, t);
        // f'(x)
        Matrix f = lastLayer.actFunc.diff(lastLayer.x);
        for (int i = 0; i < lastLayer.nodesNum; i++){
            double num = 0.;
            for (int j = 0; j < x.row; j++){
                num += E.matrix[j][i] * f.matrix[j][i];
            }
            lastLayer.delta.matrix[i][0] = num;
        }

        lastLayer.w = lastLayer.w.add(lastLayer.delta.dot(preLayer.a.appendCol(1.).meanCol()).mult(-this.eta).T());
    }

最後の行が少し長くなっていますが、式通りになっているのが分かると思います。

最終層以外
\begin{aligned}
\delta_n&=\sum\limits_{n=1}^m \left\{\delta_{next,n}w_{next,n}\right\}f'(x_{now,n})\\
w_{new}&=w_{old,n}-\eta \sum\limits_{n=1}^m \left\{\delta_{next,n}w_{next,n}\right\}f'(x_{now,n}) a_{pre}
\end{aligned}
GD.java
    /**
     * Doing back propagation.
     * @param num Number of layer.
     * @param aPre Output matrix of previous layer.
     */
    protected void backDense(int num, Matrix aPre){
        Matrix deltaNext = this.net.layers[num+1].delta;
        Matrix wNext = this.net.layers[num+1].w;
        Layer nowLayer = this.net.layers[num];

        double deltaEle = 0.;
        Matrix cal;

        // sum(delta * w)
        for (int i = 0; i < deltaNext.row; i++){
            deltaEle += wNext.getCol(i).mult(deltaNext.matrix[i][0]).sum();
        }
        for (int i = 0; i < nowLayer.nodesNum; i++){
            // 今の層のデルタ
            nowLayer.delta.matrix[i][0] = 
                deltaEle * nowLayer.actFunc.diff(nowLayer.x.getCol(i)).meanCol().matrix[0][0];
            // 重み更新
            cal = nowLayer.w.getCol(i).add(aPre.T().mult(-this.eta));
            for (int j = 0; j < nowLayer.inNum; j++){
                nowLayer.w.matrix[j][i] = cal.matrix[j][0];
            }
        }
    }

学習実行メソッド

前作同様、学習過程をコンソール出力するもの・ファイル出力するもの、テストデータの誤差も確認するもの・しないものの計四種類用意しています。

GD.java
    /**
     * Run learning.
     * @param x Input layer.
     * @param t Answer.
     * @param nEpoch Number of epoch.
     * @return Output of this network.
     */
    public Matrix fit(Matrix x, Matrix t, int nEpoch){
        Matrix y = this.net.forward(x);

        for (int i = 0; i < nEpoch; i++){
            System.out.printf("Epoch %d/%d\n", i+1, nEpoch);
            this.back(x, y, t);
            y = this.net.forward(x);
            System.out.printf("loss: %.4f\n", this.lossFunc.calc(y, t));
        }

        return y;
    }

    /**
     * Run learning.
     * @param x Input layer.
     * @param t Answer.
     * @param nEpoch Number of epoch.
     * @param valX Input layer for validation.
     * @param valT Answer for validation.
     * @return Output of this network.
     */
    public Matrix fit(Matrix x, Matrix t, int nEpoch, Matrix valX, Matrix valT){
        Matrix y = this.net.forward(x);
        Matrix valY;

        for (int i = 0; i < nEpoch; i++){
            System.out.printf("Epoch %d/%d\n", i+1, nEpoch);
            this.back(x, y, t);
            valY = this.net.forward(valX);
            y = this.net.forward(x);
            System.out.printf(
                "loss: %.4f - valLoss: %.4f\n",
                this.lossFunc.calc(y, t),
                this.lossFunc.calc(valY, valT)
            );
        }

        return y;
    }

    /**
     * Run learning.
     * @param x Input layer.
     * @param t Answer.
     * @param nEpoch Number of epoch.
     * @param fileName Name of logging file.
     * @return Output of this network.
     */
    public Matrix fit(Matrix x, Matrix t, int nEpoch, String fileName){
        Matrix y = this.net.forward(x);
        double loss = 0.;

        try(
            PrintWriter fp = new PrintWriter(fileName);
        ){
            fp.write("Epoch,loss\n");
            for (int i = 0; i < nEpoch; i++){
                this.back(x, y, t);
                y = this.net.forward(x);
                loss = this.lossFunc.calc(y, t);
                fp.printf("%d,%f\n", i+1, loss);
            }
        }catch (IOException e){
            System.out.println("IO Exception");
            System.exit(-1);
        }

        return y;
    }

    /**
     * Run learning.
     * @param x Input layer.
     * @param t Answer.
     * @param nEpoch Number of epoch.
     * @param valX Input layer for validation.
     * @param valT Answer for validation.
     * @param fileName Name of logging file.
     * @return Output of this network.
     */
    public Matrix fit(Matrix x, Matrix t, int nEpoch, Matrix valX, Matrix valT, String fileName){
        Matrix y = this.net.forward(x);
        Matrix valY;
        double loss = 0., valLoss = 0.;

        try(
            PrintWriter fp = new PrintWriter(fileName);
        ){
            fp.write("Epoch,loss,valLoss\n");
            for (int i = 0; i < nEpoch; i++){
                this.back(x, y, t);
                valY = this.net.forward(valX);
                y = this.net.forward(x);
                loss = this.lossFunc.calc(y, t);
                valLoss = this.lossFunc.calc(valY, valT);
                fp.printf("%d,%f,%f\n", i+1, loss, valLoss);
            }
        }catch (IOException e){
            System.out.println("IO Exception");
            System.exit(-1);
        }

        return y;
    }

テスト

簡単な足し算を予想してもらいました。

OptimzierTest.java
import org.MyNet2.layer.*;
import org.MyNet2.actFunc.*;
import org.MyNet2.network.*;
import org.MyNet2.optimizer.*;
import org.MyNet2.lossFunc.*;
import org.MyNet2.*;

public class OptimizerTest {
    public static void main(String[] str){
        Matrix x = new Matrix(10, 2);
        for (int i = 0; i < x.row; i++){
            for (int j = 0; j < x.col; j++){
                x.matrix[i][j] = i*0.1 + j*0.05;
            }
        }
        Matrix t = new Matrix(10, 1);
        for (int i = 0; i < t.row; i++){
            t.matrix[i][0] = x.matrix[i][0] + x.matrix[i][1];
        }

        Network net = new Network(
            2,
            new Dense(6, AFType.RELU),
            new Dense(1, AFType.LINEAR)
        );
        System.out.println(net);
        GD opt = new GD(net, new MSE());
        
        opt.fit(x, t, 5);

        System.out.println(t);
        System.out.println(net.forward(x));
    }
}

実行結果はこちら。

Network
----------------------------------------------------------------
Dense
act: ReLU
2 => 6
----------------------------------------------------------------
Dense
act: Linear
6 => 1
----------------------------------------------------------------

Epoch 1/5
loss: 0.1477
Epoch 2/5
loss: 0.0866
Epoch 3/5
loss: 0.0847
Epoch 4/5
loss: 0.0846
Epoch 5/5
loss: 0.0846
[[ 0.0500 ]
 [ 0.2500 ]
 [ 0.4500 ]
 [ 0.6500 ]
 [ 0.8500 ]
 [ 1.0500 ]
 [ 1.2500 ]
 [ 1.4500 ]
 [ 1.6500 ]
 [ 1.8500 ]]

[[ 0.4689 ]
 [ 0.5957 ]
 [ 0.7225 ]
 [ 0.8244 ]
 [ 0.9169 ]
 [ 1.0094 ]
 [ 1.1019 ]
 [ 1.1944 ]
 [ 1.2868 ]
 [ 1.3793 ]]

正解と予測値はあまりあっていませんが、一応特徴はつかんでいます。
このあと色々な学習率を試しましたが、0.02あたりから発散し0.01未満だと学習が遅くなるという結果になりました。

畳み込み層とプーリング層

参考文献にも載せていますが、yusugomori/DeepLearning at devを大いに参考にさせていただきました。というより、ほぼ写しました。

畳み込み層

バイアス

\begin{align}
b_{new,k}&=b_{old,k}-\eta\Delta b_k \nonumber \\
\Delta b_k&=\sum\limits_{i=1}^{\lfloor (I-P+1)/S \rfloor}\sum\limits_{j=1}^{\lfloor (J-Q+1)/S \rfloor}\delta_{next,k,i,j}f'(x_{now,k,i,j}+b_k) \nonumber \\
\end{align}

重み

\begin{align}
w_{new,c,p,q}&=w_{old,c,p,q}-\eta\Delta w_{old,c,p,q} \nonumber \\
\Delta w_{old,c,p,q}&=\sum\limits_{c=1}^C\sum\limits_{p=1}^P\sum\limits_{q=1}^Q\left[\sum\limits_{i=1}^{\lfloor (I-P+1)/S \rfloor}\sum\limits_{j=1}^{\lfloor (J-Q+1)/S \rfloor}\left\{\delta_{next,c,i,j}f'(x_{now,c,i,j}+b)\right\}a_{pre,c,i+p,j+q}\right]\nonumber
\end{align}

$\delta$

\begin{align}
\delta_{c,i,j}&=\sum\limits_{k=1}^K\sum\limits_{p=1}^P\sum\limits_{q=1}^Q\left\{\delta_{next,k,i-P-p+1,j-Q-q+1}f'(x_{k,i-P-p+1,j-Q-q+1}+b_k)w_{c,k,p,q}\right\} \nonumber
\end{align}

ここでは計算用行列gradBとgradWを作成し、それを使用しています。
それぞれ、上式の$\Delta b_n,\Delta w_{old,n.p,q}$のことです。

GD.java
    /**
     * Doing back propagation.
     * @param num Number of layer.
     * @param aPre Output matrix of previous layer.
     */
    protected void backConv(int num, Matrix aPre){
        Matrix deltaNext = this.net.layers[num+1].delta;
        Layer nowLayer = this.net.layers[num];
        Matrix xMeanCol = nowLayer.x.meanCol();
        int cMult = nowLayer.wRow * nowLayer.wCol;
        int kMult = nowLayer.outRow * nowLayer.outCol;
        int iMult = nowLayer.outCol;

        Matrix gradW = new Matrix(nowLayer.w.row, nowLayer.w.col);
        Matrix gradB = new Matrix(nowLayer.b.row, nowLayer.b.col);

        // gradBとgradWを計算
        for (int k = 0; k < nowLayer.kernelNum; k++){
            double d = 0.;
            // f'(x_{now,n,i,j}+b_n) ※gradBでもgradWでも使う
            Matrix fD = nowLayer.actFunc.diff(xMeanCol.add(nowLayer.b.matrix[k][0]));

            for (int i = 0; i < nowLayer.outRow; i++){
                for (int j = 0; j < nowLayer.outCol; j++){
                    // δ_{next,n,i,j}*f'(x_{now,n,i,j}+b_n)
                    d = deltaNext.matrix[k][iMult*i + j] * fD.matrix[0][kMult*k + iMult*i + j];
                    gradB.matrix[k][0] += d;

                    for (int c = 0; c < nowLayer.channelNum; c++){
                        for (int p = 0; p < nowLayer.wRow; p++){
                            for (int q = 0; q < nowLayer.wCol; q++){
                                // sum(δf'(x+b)*a_{pre,c,i+p,j+q}) ※δ、x、bの下付き文字は省略
                                gradW.matrix[k][cMult*c + nowLayer.wCol*p + q] += d * aPre.matrix[c][nowLayer.inCol*(i+p) + j+q];
                            }
                        }
                    }
                }
            }
        }

        // バイアスと重み更新
        nowLayer.b = nowLayer.b.add(gradB.mult(-this.eta));
        nowLayer.w = nowLayer.w.add(gradW.mult(-this.eta));

        double d;
        for (int k = 0; k < nowLayer.kernelNum; k++){
            // f'(x_{k,i-P-p*1,j-Q-q+1}+b_k)
            Matrix fD = nowLayer.actFunc.diff(xMeanCol.add(nowLayer.b.matrix[k][0]));
            for (int i = 0; i < nowLayer.inRow; i++){
                for (int j = 0; j < nowLayer.inCol; j++){
                    for (int c = 0; c < nowLayer.channelNum; c++){
                        for (int p = 0; p < nowLayer.wRow; p++){
                            for (int q = 0; q < nowLayer.wCol; q++){
                                if ((i - (nowLayer.wRow-1) - p < 0) || (j - (nowLayer.wCol - 1) - q < 0)){
                                    // はみ出したら
                                    d = 0.;
                                }else{
                                    // δ_{next,k,i-P--p+1,j-Q-q+1}*f'(x+b)*w_{k,c,p,q} ※x、bの下付き文字は省略
                                    d = deltaNext.matrix[k][iMult*(i-(nowLayer.wRow-1)-p) + j-(nowLayer.wCol-1)-q]
                                        * fD.matrix[0][kMult*k + iMult*(i-(nowLayer.wRow-1)-p) + j-(nowLayer.wCol-1)-q]
                                        * nowLayer.w.matrix[k][cMult*c + nowLayer.wCol*p + q];
                                }
                                nowLayer.delta.matrix[c][iMult*i + j] += d;
                            }
                        }
                    }
                }
            }
        }
    }

プーリング層

次の層から送られてきたデルタを、前の層で扱いやすい形にします。

GD.java
    /**
     * Doing back propagation.
     * @param num Number of layer.
     */
    protected void backMaxPooling(int num){
        Matrix aPre = this.net.layers[num-1].a.meanCol();
        Matrix deltaNext = this.net.layers[num+1].delta;
        Layer nowLayer = this.net.layers[num];
        int kMult = nowLayer.outRow * nowLayer.outCol;
        int iMult = nowLayer.outCol;
        int poolSize = nowLayer.poolSize;

        for (int k = 0; k < nowLayer.kernelNum; k++){
            for (int i = 0; i < nowLayer.outRow; i++){
                for (int j = 0; j < nowLayer.outCol; j++){
                    for (int p = 0; p < poolSize; p++){
                        for (int q = 0; q < poolSize; q++){
                            if (nowLayer.a.matrix[0][kMult*k + iMult*i + j]
                            == aPre.matrix[0][kMult*k + iMult*(poolSize*i+p) + poolSize*j+q]){
                                // マックスプーリングしたところ
                                nowLayer.delta.matrix[k][iMult*(poolSize*i+p) + poolSize*j+q] = deltaNext.matrix[k][iMult*i + j];
                            }else{
                                // その他
                                nowLayer.delta.matrix[k][iMult*(poolSize*i+p) + poolSize*j+q] = 0.;
                            }
                        }
                    }
                }
            }
        }
    }

テスト

以前も使用したこちらのネットワークで、合計を求めるモデルを作成します。
ついでに更新後の畳み込み層の重み行列もプリントしています。

OptimizerTest.java
import java.util.Random;
import org.MyNet2.layer.*;
import org.MyNet2.actFunc.*;
import org.MyNet2.network.*;
import org.MyNet2.optimizer.*;
import org.MyNet2.lossFunc.*;
import org.MyNet2.*;

public class OptimizerTest {
    public static void main(String[] str){
        Matrix4d x = new Matrix4d(new int[]{10, 2, 4, 4}, new Random(0));
        Matrix t = new Matrix(10, 1);
        for (int i = 0; i < t.row; i++){
            Matrix cal = x.flatten().getRow(i);
            t.matrix[i][0] = cal.sum() / cal.col;
        }

        Network net = new Network(
            2, 4, 4,
            new Conv(4, new int[]{3, 3}, AFType.RELU),
            new MaxPooling(2),
            new Dense(4, AFType.RELU),
            new Dense(1, AFType.LINEAR)
        );
        net.summary();

        GD opt = new GD(net, new MSE());
        opt.fit(x.flatten(), t, 5);

        System.out.println(t);
        System.out.println(net.forward(x.flatten()));

        Layer conv = net.layers[0];
        System.out.println(conv.w.toMatrix4d(conv.kernelNum, conv.channelNum, conv.wRow, conv.wCol));
    }
}

実行結果はこちら。

Network
----------------------------------------------------------------
Convolution
act: ReLU
2, 4, 4 => (3, 3) => 4, 2, 2
----------------------------------------------------------------
MaxPooling
act: null
4, 2, 2 => (2, 2) => 4, 1, 1
----------------------------------------------------------------
Dense
act: ReLU
4 => 4
----------------------------------------------------------------
Dense
act: Linear
4 => 1
----------------------------------------------------------------

Epoch 1/5
loss: 46631.8458
Epoch 2/5
loss: 1922131256367369000000000.0000
Epoch 3/5
loss: 66252695493432280000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000.0000
Epoch 4/5
loss: Infinity
Epoch 5/5
loss: NaN
[[ 0.5700 ]
 [ 0.5635 ]
 [ 0.4899 ]
 [ 0.5116 ]
 [ 0.4569 ]
 [ 0.4343 ]
 [ 0.5090 ]
 [ 0.4424 ]
 [ 0.4937 ]
 [ 0.4404 ]]

[[ NaN ]
 [ NaN ]
 [ NaN ]
 [ NaN ]
 [ NaN ]
 [ NaN ]
 [ NaN ]
 [ NaN ]
 [ NaN ]
 [ NaN ]]

[[[[ 0.4619 -0.5189  0.2748 ]
   [ 0.1009  0.1951 -0.3336 ]
   [-0.2296  0.9697  0.7584 ]]
  [[ 0.8825 -0.4501 -0.7422 ]
   [-0.7068 -0.9535  0.0935 ]
   [ 0.9290 -0.7910  0.2503 ]]]
 [[[-0.1784  0.5526  0.9814 ]
   [-0.0255  0.4925  0.4663 ]
   [ 0.6346  0.6778  0.0534 ]]
  [[ 0.7987 -0.7321 -0.8339 ]
   [ 0.9571  0.4447  0.4301 ]
   [-0.7136 -0.0741 -0.9910 ]]]
 [[[-0.8570 -0.3032 -0.3225 ]
   [ 0.7187  0.9431  0.7315 ]
   [ 0.2252 -0.6420 -0.5649 ]]
  [[ 0.7090 -0.9807  0.3846 ]
   [ 0.5426  0.4254 -0.5775 ]
   [ 0.5662  0.8907 -0.9715 ]]]
 [[[-0.2116  0.7076  0.5721 ]
   [ 0.9869  0.7662 -0.6594 ]
   [ 0.9241  0.4486  0.3547 ]]
  [[ 0.6088 -0.1171 -0.0758 ]
   [ 0.7057  0.0037  0.9839 ]
   [ 0.9385 -0.2938 -0.9055 ]]]]

発散しました。
学習率を0.002に変更したところ以下の結果が得られました。

Network
----------------------------------------------------------------
Convolution
act: ReLU
2, 4, 4 => (3, 3) => 4, 2, 2
----------------------------------------------------------------
MaxPooling
act: null
4, 2, 2 => (2, 2) => 4, 1, 1
----------------------------------------------------------------
Dense
act: ReLU
4 => 4
----------------------------------------------------------------
Dense
act: Linear
4 => 1
----------------------------------------------------------------

Epoch 1/5
loss: 0.0762
Epoch 2/5
loss: 0.0111
Epoch 3/5
loss: 0.0110
Epoch 4/5
loss: 0.0110
Epoch 5/5
loss: 0.0110
[[ 0.5700 ]
 [ 0.5635 ]
 [ 0.4899 ]
 [ 0.5116 ]
 [ 0.4569 ]
 [ 0.4343 ]
 [ 0.5090 ]
 [ 0.4424 ]
 [ 0.4937 ]
 [ 0.4404 ]]

[[ 0.3863 ]
 [ 0.4842 ]
 [ 0.5542 ]
 [ 0.4434 ]
 [ 0.4617 ]
 [ 0.5778 ]
 [ 0.5192 ]
 [ 0.5500 ]
 [ 0.3745 ]
 [ 0.5602 ]]

[[[[ 0.4619 -0.5189  0.2748 ]
   [ 0.1009  0.1951 -0.3336 ]
   [-0.2296  0.9697  0.7584 ]]
  [[ 0.8825 -0.4501 -0.7422 ]
   [-0.7068 -0.9535  0.0935 ]
   [ 0.9290 -0.7910  0.2503 ]]]
 [[[-0.1784  0.5526  0.9814 ]
   [-0.0255  0.4925  0.4663 ]
   [ 0.6346  0.6778  0.0534 ]]
  [[ 0.7987 -0.7321 -0.8339 ]
   [ 0.9571  0.4447  0.4301 ]
   [-0.7136 -0.0741 -0.9910 ]]]
 [[[-0.8570 -0.3032 -0.3225 ]
   [ 0.7187  0.9431  0.7315 ]
   [ 0.2252 -0.6420 -0.5649 ]]
  [[ 0.7090 -0.9807  0.3846 ]
   [ 0.5426  0.4254 -0.5775 ]
   [ 0.5662  0.8907 -0.9715 ]]]
 [[[-0.2116  0.7076  0.5721 ]
   [ 0.9869  0.7662 -0.6594 ]
   [ 0.9241  0.4486  0.3547 ]]
  [[ 0.6088 -0.1171 -0.0758 ]
   [ 0.7057  0.0037  0.9839 ]
   [ 0.9385 -0.2938 -0.9055 ]]]]

誤差は少ないですが、予測結果は微妙ですね。

フルバージョン

次回は

Adamなど他の最適化手法を実装します。

次回

参考文献

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0