Playing Othello Against the Computer 43 ~Cost Functions~

Posted at 2022-02-15

Previous article

Goal for this article

Main content

Changes

I renamed the package out_function to activationFunction.
I also added the following methods to the Matrix class.

    /**
     * Replace every element of this matrix with its absolute value (in place).
     */
    public void abs(){
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                this.matrix[i][j] = Math.abs(this.matrix[i][j]);
            }
        }
    }

    /**
     * Return the element-wise absolute value of a matrix.
     * @param in Matrix.
     * @return New Matrix holding the absolute values of in.
     */
    public static Matrix abs(Matrix in){
        Matrix rtn = in.clone();

        for (int i = 0; i < rtn.row; i++){
            for (int j = 0; j < rtn.col; j++){
                rtn.matrix[i][j] = Math.abs(in.matrix[i][j]);
            }
        }

        return rtn;
    }

    /**
     * Calculate the average of each column.
     * @return 1 x col Matrix holding the average of each column of this matrix.
     */
    public Matrix meanCol(){
        Matrix rtn = new Matrix(new double[1][this.col]);

        double num = 0;
        for (int j = 0; j < this.col; j++){
            num = 0;
            for (int i = 0; i < this.row; i++){
                num += this.matrix[i][j];
            }
            rtn.matrix[0][j] = num / this.row;
        }

        return rtn;
    }

    /**
     * Calculate the average of each column.
     * @param in A Matrix instance.
     * @return 1 x col Matrix holding the average of each column of in.
     */
    public static Matrix meanCol(Matrix in){
        Matrix rtn = new Matrix(new double[1][in.col]);

        double num = 0;
        for (int j = 0; j < in.col; j++){
            num = 0;
            for (int i = 0; i < in.row; i++){
                num += in.matrix[i][j];
            }
            rtn.matrix[0][j] = num / in.row;
        }

        return rtn;
    }

    /**
     * Calculate the square root of each element of this matrix.
     * @return New Matrix instance.
     */
    public Matrix sqrt(){
        Matrix rtn = new Matrix(new double[this.row][this.col]);

        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = Math.sqrt(this.matrix[i][j]);
            }
        }

        return rtn;
    }

    /**
     * Calculate the square root of each element of a matrix.
     * @param in Matrix instance.
     * @return New Matrix instance.
     */
    public static Matrix sqrt(Matrix in){
        Matrix rtn = new Matrix(new double[in.row][in.col]);

        for (int i = 0; i < in.row; i++){
            for (int j = 0; j < in.col; j++){
                rtn.matrix[i][j] = Math.sqrt(in.matrix[i][j]);
            }
        }

        return rtn;
    }
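
To make the shape convention concrete, here is a minimal sketch of what meanCol computes, written against a plain double[][] because the full Matrix class is not shown in this article. The class name ColumnMeanSketch is hypothetical. Note also that, in the methods above, the instance abs() modifies the matrix in place while the instance sqrt() returns a new Matrix.

```java
// Hypothetical stand-alone sketch; it does not use the article's Matrix class.
final class ColumnMeanSketch {
    /** Column-wise mean of a rows x cols array, returned as a 1 x cols array. */
    static double[][] meanCol(double[][] m) {
        int rows = m.length;
        int cols = m[0].length;
        double[][] out = new double[1][cols];
        for (int j = 0; j < cols; j++) {
            double sum = 0;
            for (int i = 0; i < rows; i++) {
                sum += m[i][j];
            }
            out[0][j] = sum / rows; // average over the rows of column j
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] m = { {1, -2}, {3, -6} };
        double[][] mean = meanCol(m);
        System.out.println(mean[0][0] + " " + mean[0][1]); // prints "2.0 -4.0"
    }
}
```

Each column is treated as one output of the network and each row as one sample, so a 5 x 10 input yields a 1 x 10 result.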

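Current directory structure

MyNet
├── costFunction  // cost function package (created this time)
├── layer         // layer package (already created)
├── matrix        // matrix package (already created)
├── network       // network package (already created)
├── nodes         // node package (already created)
│   └── activationFunction // activation function package (already created)
└── optimzer      // optimizer package (empty for now)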

CF.java

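As with the activation functions, I wanted to manage the cost functions with an enum, so I made this.
For now it lists mean squared error and mean absolute error, along with root mean squared error, since I expect to do only regression.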

package costFunction;

/**
 * Enum class for designating a cost function.
 * CF is an abbreviation of "Cost Function".
 */
public enum CF {
    MSE,
    MAE,
    RMSE
}
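
The article does not show how a CF value is turned into an actual cost function, so the following is only one possible sketch. It assumes a hypothetical select helper and works on plain double[] arrays (one column of outputs) instead of the Matrix class. The RMSE branch uses the same form as the RootMeanSquaredError class in this article.

```java
import java.util.function.ToDoubleBiFunction;

// Hypothetical sketch: select and CostFunctionSelectorSketch are not part of MyNet.
final class CostFunctionSelectorSketch {
    enum CF { MSE, MAE, RMSE }

    /** Sum of squared errors (squared == true) or absolute errors (squared == false). */
    private static double sumOfErrors(double[] y, double[] t, boolean squared) {
        double sum = 0;
        for (int i = 0; i < y.length; i++) {
            double d = y[i] - t[i];
            sum += squared ? d * d : Math.abs(d);
        }
        return sum;
    }

    /** Map an enum constant to a function (y, t) -> cost. */
    static ToDoubleBiFunction<double[], double[]> select(CF type) {
        switch (type) {
            case MSE:
                return (y, t) -> sumOfErrors(y, t, true) / y.length;
            case MAE:
                return (y, t) -> sumOfErrors(y, t, false) / y.length;
            case RMSE:
                // Same form as this article's RootMeanSquaredError: sqrt(sum) / N.
                return (y, t) -> Math.sqrt(sumOfErrors(y, t, true)) / y.length;
            default:
                throw new IllegalArgumentException("Unknown cost function: " + type);
        }
    }

    public static void main(String[] args) {
        double[] y = {1, 2};
        double[] t = {0, 4};
        System.out.println(select(CF.MAE).applyAsDouble(y, t)); // prints "1.5"
        System.out.println(select(CF.MSE).applyAsDouble(y, t)); // prints "2.5"
    }
}
```

Keeping the mapping in one switch means a new enum constant only needs one new branch.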

CostFunction.java

The parent class of the cost functions.
The method bodies are just placeholders.

package costFunction;

import matrix.*;

/**
 * Cost function's base class.
 * All cost functions must extend this class.
 */
public class CostFunction {
    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public CostFunction(){
        ;
    }

    /**
     * Calculate this cost function.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return Difference between y and t.
     */
    public Matrix calcurate(Matrix y, Matrix t){
        return Matrix.sub(y, t);
    }

    /**
     * Calculate this cost function's differential.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return The result of differentiating the difference between y and t.
     */
    public Matrix differential(Matrix y, Matrix t){
        return Matrix.sub(y, t);
    }
}

MeanAbsoluteError.java

Mean absolute error.

$$ E(w) = \frac{1}{N}\sum_{n=1}^{N}|y_n-t_n| $$

Here $t$ is the ground truth, $y$ is the predicted value, $N$ is the number of data points, and $E$ is the error.

package costFunction;

import matrix.*;

/**
 * Mean absolute error (MAE) cost function.
 */
public class MeanAbsoluteError extends CostFunction {
    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public MeanAbsoluteError(){
        ;
    }

    /**
     * Calculate this cost function.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return MAE between y and t (one value per column).
     */
    public Matrix calcurate(Matrix y, Matrix t){
        Matrix rtn = Matrix.abs(Matrix.sub(y, t));
        return rtn.meanCol();
    }
}

MeanSquaredError.java

Mean squared error.

$$ E(w) = \frac{1}{N}\sum_{n=1}^{N}(y_n-t_n)^2 $$

package costFunction;

import matrix.*;

/**
 * Mean squared error (MSE) cost function.
 */
public class MeanSquaredError extends CostFunction {
    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public MeanSquaredError(){
        ;
    }

    /**
     * Calculate this cost function.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return MSE between y and t (one value per column).
     */
    public Matrix calcurate(Matrix y, Matrix t){
        Matrix rtn = Matrix.abs(Matrix.sub(y, t));

        for (int i = 0; i < rtn.row; i++){
            for (int j = 0; j < rtn.col; j++){
                rtn.matrix[i][j] *= rtn.matrix[i][j];
            }
        }

        return rtn.meanCol();
    }
}
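
Neither MeanAbsoluteError nor MeanSquaredError overrides differential, so both inherit the placeholder from CostFunction, which returns y - t (the derivative of $\frac{1}{2}(y-t)^2$ with respect to $y$). As a rough sketch of what the derivatives of the two formulas above look like, assuming they are taken with respect to each $y_n$, the following uses plain double[] arrays. The names CostGradientSketch, mseGradient, and maeGradient are hypothetical and not part of MyNet.

```java
// Hypothetical sketch of dE/dy_n for the MSE and MAE formulas in this article.
final class CostGradientSketch {
    /** MSE: E = (1/N) * sum (y_n - t_n)^2, so dE/dy_n = 2 * (y_n - t_n) / N. */
    static double[] mseGradient(double[] y, double[] t) {
        int n = y.length;
        double[] g = new double[n];
        for (int i = 0; i < n; i++) {
            g[i] = 2.0 * (y[i] - t[i]) / n;
        }
        return g;
    }

    /**
     * MAE: E = (1/N) * sum |y_n - t_n|, so dE/dy_n = sign(y_n - t_n) / N.
     * At y_n == t_n the absolute value is not differentiable; this sketch uses 0 there.
     */
    static double[] maeGradient(double[] y, double[] t) {
        int n = y.length;
        double[] g = new double[n];
        for (int i = 0; i < n; i++) {
            g[i] = Math.signum(y[i] - t[i]) / n;
        }
        return g;
    }

    public static void main(String[] args) {
        double[] y = {1, 2};
        double[] t = {0, 4};
        double[] mse = mseGradient(y, t);
        double[] mae = maeGradient(y, t);
        System.out.println(mse[0] + " " + mse[1]); // prints "1.0 -2.0"
        System.out.println(mae[0] + " " + mae[1]); // prints "0.5 -0.5"
    }
}
```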

RootMeanSquaredError.java

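Root mean squared error.

$$ E(w) = \frac{1}{N}\sqrt{\sum_{n=1}^{N}(y_n-t_n)^2} $$

Note that this form, which is what the code below computes, divides by $N$ outside the square root. The usual definition of RMSE is $\sqrt{\frac{1}{N}\sum_{n=1}^{N}(y_n-t_n)^2}$, which is larger by a factor of $\sqrt{N}$.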

package costFunction;

import matrix.*;
import java.lang.Math;

/**
 * Root mean squared error (RMSE) cost function.
 */
public class RootMeanSquaredError extends CostFunction {
    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public RootMeanSquaredError(){
        ;
    }

    /**
     * Calculate this cost function.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return RMSE between y and t, computed as sqrt(sum of squares) / N per column.
     */
    public Matrix calcurate(Matrix y, Matrix t){
        Matrix rtn = new Matrix(new double[1][y.col]);
        double num;

        for (int j = 0; j < y.col; j++){
            num = 0;
            for (int i = 0; i < y.row; i++){
                num += Math.pow(y.matrix[i][j] - t.matrix[i][j], 2);
            }
            rtn.matrix[0][j] = Math.sqrt(num);
        }

        return Matrix.div(rtn, y.row);
    }
}
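
The class above computes sqrt(sum) / N, whereas RMSE is usually defined as sqrt(sum / N). The following sketch compares the two on a single column, using plain double[] arrays and the hypothetical names RmseSketch, articleForm, and standardRmse. The sample values are the first column of the test data used in test.java in this article.

```java
import java.util.Locale;

// Hypothetical sketch comparing the form used in this article with the usual RMSE.
final class RmseSketch {
    /** Sum of squared differences between y and t. */
    static double sumSquaredError(double[] y, double[] t) {
        double sum = 0;
        for (int i = 0; i < y.length; i++) {
            double d = y[i] - t[i];
            sum += d * d;
        }
        return sum;
    }

    /** Form used by RootMeanSquaredError in this article: sqrt(sum) / N. */
    static double articleForm(double[] y, double[] t) {
        return Math.sqrt(sumSquaredError(y, t)) / y.length;
    }

    /** Usual definition of RMSE: sqrt(sum / N). */
    static double standardRmse(double[] y, double[] t) {
        return Math.sqrt(sumSquaredError(y, t) / y.length);
    }

    public static void main(String[] args) {
        double[] y = {0.0, 1.5, 3.0, 4.5, 6.0};
        double[] t = {0.0, 1.4, 2.8, 4.2, 5.6};
        // prints "0.1095 0.2449"
        System.out.println(String.format(Locale.ROOT, "%.4f %.4f",
                articleForm(y, t), standardRmse(y, t)));
    }
}
```

The two values differ by a factor of sqrt(N), so the usual RMSE is the square root of the MSE for the same column, while the form used here is not.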

test.java

I tried out the classes I made this time.

import matrix.Matrix;
import costFunction.*;

public class test {
    public static void main(String[] str){
        double[][] m = new double[5][10];
        Matrix a = new Matrix(m);
        Matrix b = new Matrix(m);

        for (int i = 0; i < a.row; i++){
            for (int j = 0; j < a.col; j++){
                a.matrix[i][j] = (double)i*1.5 + (double)j*0.5;
                b.matrix[i][j] = (double)i*1.4 + (double)j*0.6;
            }
        }

        System.out.println(a);
        System.out.println(b);
        CostFunction f = new MeanAbsoluteError();
        System.out.println(f.calcurate(a, b));
        f = new MeanSquaredError();
        System.out.println(f.calcurate(a, b));
        f = new RootMeanSquaredError();
        System.out.println(f.calcurate(a, b));
    }
}

Here is the output.

[[0.0000 0.5000 1.0000 1.5000 2.0000 2.5000 3.0000 3.5000 4.0000 4.5000 ]
 [1.5000 2.0000 2.5000 3.0000 3.5000 4.0000 4.5000 5.0000 5.5000 6.0000 ]
 [3.0000 3.5000 4.0000 4.5000 5.0000 5.5000 6.0000 6.5000 7.0000 7.5000 ]
 [4.5000 5.0000 5.5000 6.0000 6.5000 7.0000 7.5000 8.0000 8.5000 9.0000 ]
 [6.0000 6.5000 7.0000 7.5000 8.0000 8.5000 9.0000 9.5000 10.0000 10.5000 ]]

[[0.0000 0.6000 1.2000 1.8000 2.4000 3.0000 3.6000 4.2000 4.8000 5.4000 ]
 [1.4000 2.0000 2.6000 3.2000 3.8000 4.4000 5.0000 5.6000 6.2000 6.8000 ]
 [2.8000 3.4000 4.0000 4.6000 5.2000 5.8000 6.4000 7.0000 7.6000 8.2000 ]
 [4.2000 4.8000 5.4000 6.0000 6.6000 7.2000 7.8000 8.4000 9.0000 9.6000 ]
 [5.6000 6.2000 6.8000 7.4000 8.0000 8.6000 9.2000 9.8000 10.4000 11.0000 ]]

[[0.2000 0.1400 0.1200 0.1400 0.2000 0.3000 0.4000 0.5000 0.6000 0.7000 ]]

[[0.0600 0.0300 0.0200 0.0300 0.0600 0.1100 0.1800 0.2700 0.3800 0.5100 ]]

[[0.1095 0.0775 0.0632 0.0775 0.1095 0.1483 0.1897 0.2324 0.2757 0.3194 ]]

I tried the same calculation in Excel and got the following.

[Figure: image.png, Excel results]

You can see that the results are the same.
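
As a hand check of the first column of the output above: the values of a are 0, 1.5, 3.0, 4.5, 6.0 and those of b are 0, 1.4, 2.8, 4.2, 5.6, so the absolute differences are 0, 0.1, 0.2, 0.3, 0.4 with $N = 5$. Applying the three formulas as defined in this article gives:

$$ E_{\mathrm{MAE}} = \frac{0 + 0.1 + 0.2 + 0.3 + 0.4}{5} = 0.2 $$

$$ E_{\mathrm{MSE}} = \frac{0 + 0.01 + 0.04 + 0.09 + 0.16}{5} = 0.06 $$

$$ E_{\mathrm{RMSE}} = \frac{\sqrt{0 + 0.01 + 0.04 + 0.09 + 0.16}}{5} = \frac{\sqrt{0.30}}{5} \approx 0.1095 $$

These match the first entries of the three printed result rows (0.2000, 0.0600, 0.1095).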

Next time

I would like to build the optimizer.
This is the part I understand least, so to begin with it will probably cover only gradient descent.

Next article
