
Othello Against the Computer 53: Loss Functions

Posted at 2022-04-18

Previous post

Goal for this post

Create loss functions.

Main content

Softmax

Before adding the loss functions, I will add the softmax function as an activation function.
Let N be the number of nodes in the layer, and let i be the index of the node currently under consideration. Then:

\begin{align}
f(x_i)&=\frac{e^{x_i}}{\sum\limits_{n=1}^Ne^{x_n}}\nonumber \\
\frac{\partial f(x_i)}{\partial x_j}&=\left\{\begin{aligned}
&f(x_i)(1-f(x_i))\ &if\ i=j\\
&-f(x_i)f(x_j)\ &if\ i\ne j
\end{aligned}\right.\nonumber
\end{align}
Program
Softmax.java
package org.MyNet2.actFunc;

import java.lang.Math;
import org.MyNet2.*;

/**
 * Softmax function.
 */
public class Softmax extends ActivationFunction {
    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public Softmax(){
        ;
    }

    /**
     * Execute this activation function.
     * @param in Linearly transformed matrix.
     * @return output matrix.
     */
    @Override
    public Matrix calc(Matrix in){
        Matrix rtn = new Matrix(in.row, in.col);
        double denominator;
        double cal;
        for (int i = 0; i < rtn.row; i++){
            denominator = 0.;
            for (int j = 0; j < rtn.col; j++){
                cal = Math.exp(in.matrix[i][j]);
                rtn.matrix[i][j] = cal;
                denominator += cal;
            }
            for (int j = 0; j < rtn.col; j++){
                rtn.matrix[i][j] /= denominator;
            }
        }

        return rtn;
    }

    /**
     * Calculate this activation function's derivative
     * (only the diagonal terms i = j are used).
     * @param in Matrix of input.
     * @return The result of differentiating this activation function.
     */
    @Override
    public Matrix diff(Matrix in){
        Matrix rtn = new Matrix(in.row, in.col);
        double denominator;
        double cal;
        for (int i = 0; i < rtn.row; i++){
            denominator = 0.;
            for (int j = 0; j < rtn.col; j++){
                cal = Math.exp(in.matrix[i][j]);
                rtn.matrix[i][j] = cal;
                denominator += cal;
            }
            for (int j = 0; j < rtn.col; j++){
                rtn.matrix[i][j] /= denominator;
                rtn.matrix[i][j] = rtn.matrix[i][j] * (1-rtn.matrix[i][j]);
            }
        }

        return rtn;
    }

    @Override
    public String toString(){
        return "Softmax";
    }
}
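As a quick sanity check of the softmax logic above, here is a standalone sketch using a plain array instead of the article's Matrix class (the `SoftmaxSketch` name and helper are my own, not part of MyNet2): every row of softmax outputs should sum to 1.

```java
public class SoftmaxSketch {
    // Softmax over one row, matching Softmax.calc:
    // exponentiate each entry, then divide by the row sum.
    static double[] softmax(double[] x) {
        double[] out = new double[x.length];
        double denom = 0.0;
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.exp(x[i]);
            denom += out[i];
        }
        for (int i = 0; i < x.length; i++) {
            out[i] /= denom;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] p = softmax(new double[]{1.0, 2.0, 3.0});
        double sum = 0.0;
        for (double v : p) sum += v;
        // Probabilities sum to 1 (up to floating-point error).
        System.out.println(Math.abs(sum - 1.0) < 1e-12); // prints true
    }
}
```

The same normalization runs once per row in `Softmax.calc`, which is why `denominator` is reset inside the outer loop.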

Current directory structure

MyNet2
├── actFunc
│   └── activation function classes
├── layer
│   └── layer classes
├── lossFunc
│   └── loss function classes
├── network
│   └── network classes
├── tests
│   └── test classes
└── matrix class

MeanAbsoluteError

Mean absolute error.
Let y be the network's output, t the ground-truth data, and N the number of data points. Then the loss is as follows (the loss value is a scalar, while the result of differentiation is a matrix):

\begin{align}
E(y)&=\frac{1}{N}\sum\limits_{n=1}^N|y_n-t_n|\nonumber \\
E'(y_n)&=\left\{\begin{aligned}
&1\ &if\ y_n-t_n>0\\
&-1\ &if\ y_n-t_n<0
\end{aligned}\right.\nonumber
\end{align}
Program
MAE.java
package org.MyNet2.lossFunc;

import java.lang.Math;
import org.MyNet2.*;

/**
 * Class for loss function.
 */
public class MAE extends LossFunction {
    /**
     * Constructor for this class.
     */
    public MAE(){
        ;
    }

    /**
     * Calculate this loss function.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return Difference between y and t.
     */
    @Override
    public double calc(Matrix y, Matrix t){
        double rtn = 0.;

        for (int i = 0; i < y.row; i++){
            for (int j = 0; j < y.col; j++){
                rtn += Math.abs(y.matrix[i][j] - t.matrix[i][j]);
            }
        }

        return rtn / y.row;
    }

    /**
     * Calculate this loss function's derivative.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return The result of differentiating the difference between y and t.
     */
    @Override
    public Matrix diff(Matrix y, Matrix t){
        Matrix rtn = new Matrix(y.row, y.col);

        for (int i = 0; i < rtn.row; i++){
            for (int j = 0; j < rtn.col; j++){
                if (y.matrix[i][j] - t.matrix[i][j] > 0){
                    rtn.matrix[i][j] = 1.;
                }else{
                    rtn.matrix[i][j] = -1.;
                }
            }
        }

        return rtn;
    }

    @Override
    public String toString(){
        return "MeanAbsoluteError";
    }
}
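A minimal sketch of the same MAE computation on plain arrays (again, `MAESketch` is my own illustration, not part of MyNet2), to confirm the per-row averaging:

```java
public class MAESketch {
    // Mean absolute error over a batch of scalar outputs,
    // matching MAE.calc with one column per row: sum |y - t|, divide by N.
    static double mae(double[] y, double[] t) {
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            sum += Math.abs(y[i] - t[i]);
        }
        return sum / y.length;
    }

    public static void main(String[] args) {
        double[] y = {1.0, 2.0, 5.0};
        double[] t = {0.0, 4.0, 5.5};
        // |1| + |-2| + |-0.5| = 3.5, divided by N = 3
        System.out.println(mae(y, t) == 3.5 / 3.0); // prints true
    }
}
```

Note that, as in `MAE.calc`, the sum is divided by the number of rows (samples), not by the total number of matrix entries.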

MeanSquaredError

Mean squared error.

\begin{align}
E(y)&=\frac{1}{N}\sum\limits_{n=1}^N(y_n-t_n)^2\nonumber \\
E'(y)&=2(y-t)\nonumber
\end{align}
Program
MSE.java
package org.MyNet2.lossFunc;

import java.lang.Math;
import org.MyNet2.*;

/**
 * Class for loss function.
 */
public class MSE extends LossFunction {
    /**
     * Constructor for this class.
     */
    public MSE(){
        ;
    }

    /**
     * Calculate this loss function.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return Difference between y and t.
     */
    @Override
    public double calc(Matrix y, Matrix t){
        double rtn = 0.;

        for (int i = 0; i < y.row; i++){
            for (int j = 0; j < y.col; j++){
                rtn += Math.pow(y.matrix[i][j] - t.matrix[i][j], 2);
            }
        }

        return rtn / y.row;
    }

    /**
     * Calculate this loss function's derivative.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return The result of differentiating the difference between y and t.
     */
    @Override
    public Matrix diff(Matrix y, Matrix t){
        Matrix rtn = new Matrix(y.row, y.col);

        for (int i = 0; i < y.row; i++){
            for (int j = 0; j < y.col; j++){
                rtn.matrix[i][j] = (y.matrix[i][j] - t.matrix[i][j]) * 2;
            }
        }

        return rtn;
    }

    @Override
    public String toString(){
        return "MeanSquaredError";
    }
}
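The MSE loss and its gradient 2(y - t) can likewise be sketched on plain arrays (`MSESketch` is an illustration of mine, not part of MyNet2):

```java
public class MSESketch {
    // Mean squared error over a batch: sum of squared differences,
    // divided by the number of rows as in MSE.calc.
    static double mse(double[] y, double[] t) {
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            double d = y[i] - t[i];
            sum += d * d;
        }
        return sum / y.length;
    }

    // Element-wise gradient 2 * (y - t), as in MSE.diff.
    static double[] grad(double[] y, double[] t) {
        double[] g = new double[y.length];
        for (int i = 0; i < y.length; i++) {
            g[i] = 2.0 * (y[i] - t[i]);
        }
        return g;
    }

    public static void main(String[] args) {
        double[] y = {1.0, 3.0};
        double[] t = {0.0, 1.0};
        System.out.println(mse(y, t));     // (1 + 4) / 2, prints 2.5
        System.out.println(grad(y, t)[1]); // 2 * (3 - 1), prints 4.0
    }
}
```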

CategoricalCrossEntropy

Multi-class cross-entropy. Let C be the number of classes.

\begin{align}
E(y)&=-\frac{1}{N}\sum\limits_{n=1}^{N}\sum\limits_{c=1}^{C}t_{n,c}\log (y_{n,c})\nonumber \\
E'(y)&=-\frac{t}{y}\nonumber
\end{align}
Program
CCE.java
package org.MyNet2.lossFunc;

import java.lang.Math;
import org.MyNet2.*;

/**
 * Class for loss function.
 */
public class CCE extends LossFunction {
    /**
     * Constructor for this class.
     */
    public CCE(){
        ;
    }

    /**
     * Calculate this loss function.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return Difference between y and t.
     */
    @Override
    public double calc(Matrix y, Matrix t){
        double rtn = 0.;

        for (int i = 0; i < y.row; i++){
            for (int j = 0; j < y.col; j++){
                rtn += -t.matrix[i][j] * Math.log(y.matrix[i][j]);
            }
        }

        return rtn / y.row;
    }

    /**
     * Calculate this loss function's derivative.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return The result of differentiating the difference between y and t.
     */
    @Override
    public Matrix diff(Matrix y, Matrix t){
        Matrix rtn = new Matrix(y.row, y.col);

        for (int i = 0; i < rtn.row; i++){
            for (int j = 0; j < rtn.col; j++){
                rtn.matrix[i][j] = -t.matrix[i][j] / y.matrix[i][j];
            }
        }

        return rtn;
    }

    @Override
    public String toString(){
        return "CategoricalCrossEntropy";
    }
}
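For a one-hot target, only the correct class contributes to the sum, so the per-sample loss reduces to -log of the predicted probability of the correct class. A standalone sketch on plain arrays (`CCESketch` is my own illustration, not part of MyNet2):

```java
public class CCESketch {
    // Cross-entropy for one sample: -sum over classes of t_c * log(y_c),
    // matching the inner loop of CCE.calc.
    static double crossEntropy(double[] y, double[] t) {
        double sum = 0.0;
        for (int c = 0; c < y.length; c++) {
            sum += -t[c] * Math.log(y[c]);
        }
        return sum;
    }

    public static void main(String[] args) {
        // One-hot target: the terms for t_c = 0 vanish,
        // leaving -log(y_correct).
        double[] y = {0.7, 0.2, 0.1};
        double[] t = {1.0, 0.0, 0.0};
        System.out.println(crossEntropy(y, t) == -Math.log(0.7)); // prints true
    }
}
```

This also shows why `y` is usually produced by the softmax layer above: `Math.log` requires strictly positive inputs, which softmax guarantees.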

Full version

Next time

I will implement backpropagation.

Next post

References
