LoginSignup
0

More than 1 year has passed since last update.

コンピュータとオセロ対戦51 ~層~

Last updated at Posted at 2022-04-11

前回

今回の目標

全結合層、畳み込み層、プーリング層を作る

ここから本編

現在のディレクトリ構成

MyNet2
├── Matrix.java
├── actFunc
│   └── 活性化関数クラス
├── layer
│   └── 層クラス
└── tests
    └── テストクラス

actFunc

活性化関数についてはMyNetとほぼ同じですので、省略します。

Layer

層について、基本的な定義づけを行います。
まず層とは、ここではノードが一列に並んだものを指すことにします。ただ、これまでのように入力層や出力層を特別扱いすることはありません。また、畳み込み層とプーリング層も層です。
MyNetはノードクラスを用いたことで畳み込み層やドロップアウトなどの作成が困難でしたが、MyNet2ではネットワークの最小単位が層になるので、これらも作ってみたいと思います。

まず、層の親クラス。ほぼ空ですが、全ての層に共通する事柄をまとめています。

プログラム
Layer.java
package org.MyNet2.layer;

import java.io.Serializable;
import org.MyNet2.*;
import org.MyNet2.actFunc.*;

/**
 * Class for layer.
 */
public class Layer implements Serializable {
    /** Type of activation function for this layer. */
    public AFType afType;
    /** Activation function of this layer. */
    public ActivationFunction actFunc;
    /** Name of this layer's activation function. */
    public String actFuncName;
    /** Name of this layer */
    public String name = null;

    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public Layer(){
        ;
    }


    /**
     * Doing forward propagation.
     * @param in input matrix.
     * @return Matrix instance of output.
     */
    public Matrix forward(Matrix in){
        return in.clone();
    }

    /**
     * Doing back propagation.
     */
    public void back(){
        ;
    }

    /**
     * Calucrate delta each nodes.
     */
    public void calDelta(){
        ;
    }

    @Override
    public String toString(){
        return "";
    }
}

Dense

全結合層。
入力データは、DataFrameやDataBaseと同様の方式の二次元行列であるとします。つまり、列名が説明変数名、行名がデータ番号です。

例えば、以下のようになるとします。

image.png

ここで、バイアスは必ず1になるものとし、この1に重みをかけます。その重みを調節することで間接的にバイアスを調整することにします。これは今までの方針と同様です。

層が持つ各ノードの重みについても、二次元行列で保持することにします。
ノード数が7で、説明変数が5個の時、重みは以下のようになります。

image.png

各ノードで、各説明変数とバイアスにかける重み6個、ノードが7個なので7x6の42個の倍精度実数を保持します。
そして、入力行列と重み行列を掛け合わせることで線形変換後の値が得られます。

image.png

小さな数字が出ているのは、有効桁数の問題だと思うので気にしないでください。
活性化関数がReLuのとき、この層の出力は以下のようになります(0に近い数字は0と表記しています)。

image.png

これをJavaで作っていきたいと思います。

コンストラクタとフィールド

全結合層は活性化関数に加え、重み、入力数、ノード数、線形変換後の値、出力値をフィールドとして保持します。
また、後々作成するネットワーククラスにて具体的な値を入れようと思うので、ノード数のみを指定し重みの初期化は行わないコンストラクタも用意します。

Dense.java
package org.MyNet2.layer;

import java.util.Random;
import org.MyNet2.*;
import org.MyNet2.actFunc.*;

/**
 * Class for dense layer.
 */
public class Dense extends Layer {
    /** The list of weight of this layer. */
    public Matrix w;
    /** Number of inputs contain bias. */
    public int inNum;
    /** Number of nodes of this class. */
    public int nodesNum;
    /** Liner transformed matrix. */
    public Matrix x;
    /** Matrix of output from this layer. */
    public Matrix a;

    /**
     * Constructor for this class.
     * @param nodesNum Number of nodes.
     * @param afType Type of activation function.
     */
    public Dense(int nodesNum, AFType afType){
        this.name = "Dense";
        this.nodesNum = nodesNum;
        this.afType = afType;
    }

    /**
     * Constructor for this class.
     * @param inNum Number of inputs don't contain bias.
     * @param nodesNum Number of nodes of this class.
     * @param afType Type of activation function for this layer.
     */
    public Dense(int inNum, int nodesNum, AFType afType){
        this.setup(inNum, nodesNum, afType, 0);
    }

    /**
     * Constructor for this class.
     * @param inNum Number of inputs don't contain bias.
     * @param nodesNum Number of nodes of this class.
     * @param afType Type of activation function for this layer.
     * @param seed Number of seed for random class.
     */
    public Dense(int inNum, int nodesNum, AFType afType, long seed){
        this.setup(inNum, nodesNum, afType, seed);
    }

    /**
     * Construct instead of constructor.
     * @param inNum Number of inputs don't contain bias.
     * @param nodesNum Number of nodes of this class.
     * @param afType Type of activation function for this layer.
     * @param seed Number of seed for random class.
     */
    protected void setup(int inNum, int nodesNum, AFType afType, long seed){
        this.name = "Dense";
        this.inNum = inNum + 1;
        this.nodesNum = nodesNum;
        
        this.w = new Matrix(this.inNum, nodesNum, new Random(seed), -1, 1);
        this.x = new Matrix(this.inNum, this.nodesNum);
        this.a = new Matrix(this.inNum, this.nodesNum);

        switch(afType) {
        case SIGMOID:
            this.actFunc = new Sigmoid();
            break;
        case RELU:
            this.actFunc = new ReLu();
            break;
        case TANH:
            this.actFunc = new Tanh();
            break;
        case LINER:
            this.actFunc = new Liner();
            break;
        default:
            System.out.println("ERROR: The specified activation function is wrong");
            System.exit(-1);
        }
        this.actFuncName = this.actFunc.toString();
    }

順方向計算

ノードクラスがないため、順方向計算は以下のプログラムで済みます。

Dense.java
    @Override
    public Matrix forward(Matrix in){
        Matrix in_ = in.appendCol(1.0);
        return this.actFunc.calc(in_.dot(this.w));
    }

toString

Dense.java
    @Override
    public String toString(){
        String str = String.format(
            "----------------------------------------------------------------\n"
            + "Dense\n"
            + "nodes num: %d, activation function: %s", this.nodesNum, this.actFuncName
        );

        return str;
    }

テスト

Excelで示した上記データを実際に入れてみたいと思います。

DenseTest.java
import org.MyNet2.*;
import org.MyNet2.layer.*;
import org.MyNet2.actFunc.*;

public class DenseTest {
    public static void main(String[] str){
        // 全結合層作成と、重みの初期化
        // テストのためにこうやって重みを初期化しているが、学習するときは乱数で初期化する。
        Dense layer = new Dense(5, 7, AFType.RELU);
        System.out.println(layer);
        for (int i = 0; i < layer.w.row; i++){
            for (int j = 0; j < layer.w.col; j++){
                layer.w.matrix[i][j] = i * 0.2 - j * 0.1;
            }
        }
        System.out.println(layer.w);

        // 入力データの用意
        Matrix in = new Matrix(10, 5);
        for (int i = 0; i < in.row; i++){
            for (int j = 0; j < in.col; j++){
                in.matrix[i][j] = i * 0.1 - j * 0.2;
            }
        }
        System.out.println(in);

        // 入力データを実際に入れ、順方向計算
        System.out.println(layer.forward(in));
    }
}

実行結果はこちら。

----------------------------------------------------------------
Dense
nodes num: 7, activation function: ReLu
[[ 0.0000 -0.1000 -0.2000 -0.3000 -0.4000 -0.5000 -0.6000 ]
 [ 0.2000  0.1000  0.0000 -0.1000 -0.2000 -0.3000 -0.4000 ]
 [ 0.4000  0.3000  0.2000  0.1000  0.0000 -0.1000 -0.2000 ]
 [ 0.6000  0.5000  0.4000  0.3000  0.2000  0.1000  0.0000 ]
 [ 0.8000  0.7000  0.6000  0.5000  0.4000  0.3000  0.2000 ]
 [ 1.0000  0.9000  0.8000  0.7000  0.6000  0.5000  0.4000 ]]

[[ 0.0000 -0.2000 -0.4000 -0.6000 -0.8000 ]
 [ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 ]
 [ 0.2000  0.0000 -0.2000 -0.4000 -0.6000 ]
 [ 0.3000  0.1000 -0.1000 -0.3000 -0.5000 ]
 [ 0.4000  0.2000  0.0000 -0.2000 -0.4000 ]
 [ 0.5000  0.3000  0.1000 -0.1000 -0.3000 ]
 [ 0.6000  0.4000  0.2000  0.0000 -0.2000 ]
 [ 0.7000  0.5000  0.3000  0.1000 -0.1000 ]
 [ 0.8000  0.6000  0.4000  0.2000  0.0000 ]
 [ 0.9000  0.7000  0.5000  0.3000  0.1000 ]]

[[ 0.0000  0.0000  0.0000  0.1000  0.2000  0.3000  0.4000 ]
 [ 0.0000  0.0500  0.1000  0.1500  0.2000  0.2500  0.3000 ]
 [ 0.2000  0.2000  0.2000  0.2000  0.2000  0.2000  0.2000 ]
 [ 0.4000  0.3500  0.3000  0.2500  0.2000  0.1500  0.1000 ]
 [ 0.6000  0.5000  0.4000  0.3000  0.2000  0.1000  0.0000 ]
 [ 0.8000  0.6500  0.5000  0.3500  0.2000  0.0500  0.0000 ]
 [ 1.0000  0.8000  0.6000  0.4000  0.2000  0.0000  0.0000 ]
 [ 1.2000  0.9500  0.7000  0.4500  0.2000  0.0000  0.0000 ]
 [ 1.4000  1.1000  0.8000  0.5000  0.2000  0.0000  0.0000 ]
 [ 1.6000  1.2500  0.9000  0.5500  0.2000  0.0000  0.0000 ]]

同じ結果になっていることが分かると思います。

Matrix3d

畳み込み層を作る前に、三次元と四次元の行列クラスを作ります。
四次元行列クラス作成のため、三次元行列クラスを作成します。
とりあえずコンストラクタとcloneメソッドのみ。

Matrix3d.java
package org.MyNet2;

import java.util.ArrayList;
import java.util.Random;
import org.MyNet2.*;

/**
 * Class for three dimentional matrix.
 */
public class Matrix3d {
    /** Value of this matrix. */
    public ArrayList<Matrix> matrix;
    /** Shape of this matrix. */
    public int[] shape;

    /**
     * Exit this program.
     * @param msg Message for printing.
     */
    protected void exit(String msg){
        System.out.println(msg);
        System.exit(-1);
    }

    /**
     * Constructor for this class.
     * @param shape Shape of this matrix.
     */
    public Matrix3d(int[] shape){
        this.shape = shape.clone();
        this.matrix = new ArrayList<Matrix>();

        for (int i = 0; i < this.shape[0]; i++){
            this.matrix.add(new Matrix(this.shape[1], this.shape[2]));
        }
    }

    /**
     * Constructor for this class.
     * @param shape Shape of this matrix.
     * @param num Number to fill.
     */
    public Matrix3d(int[] shape, double num){
        this.shape = shape.clone();
        this.matrix = new ArrayList<Matrix>();

        for (int i = 0; i < this.shape[0]; i++){
            this.matrix.add(new Matrix(this.shape[1], this.shape[2], num));
        }
    }

    /**
     * Constructor for this class.
     * @param shape Shape of this matrix.
     * @param rand Random instance.
     */
    public Matrix3d(int[] shape, Random rand){
        this.shape = shape.clone();
        this.matrix = new ArrayList<Matrix>();

        for (int i = 0; i < this.shape[0]; i++){
            this.matrix.add(new Matrix(this.shape[1], this.shape[2], rand));
        }
    }

    /**
     * Constructor for this class.
     * @param shape Shape of this matrix.
     * @param rand Random instance.
     * @param min Number of min for range.
     * @param max Number of max for range.
     */
    public Matrix3d(int[] shape, Random rand, double min, double max){
        this.shape = shape.clone();
        this.matrix = new ArrayList<Matrix>();

        for (int i = 0; i < this.shape[0]; i++){
            this.matrix.add(new Matrix(this.shape[1], this.shape[2], rand, min, max));
        }
    }

    @Override
    public Matrix3d clone(){
        Matrix3d rtn = new Matrix3d(this.shape);

        for (int i = 0; i < this.shape[0]; i++){
            rtn.matrix.set(i, this.matrix.get(i).clone());
        }

        return rtn;
    }

    @Override
    public String toString(){
        String rtn = "[";
        int j;

        for (int i = 0; i < this.shape[0]; i++){
            if (i == 0){
                rtn += "[";
            }else{
                rtn += "\n [";
            }
            j = 0;
            for (double[] ele: this.matrix.get(i).matrix){
                if (j == 0){
                    rtn += "[";
                }else{
                    rtn += "\n  [";
                }
                j++;
                for (double num: ele){
                    if (num < 0){
                        rtn += String.format("%.4f ", num);
                    }else{
                        rtn += String.format(" %.4f ", num);
                    }
                }
                rtn += "]";
            }
            rtn += "]";
        }

        rtn += "]\n";

        return rtn;
    }
}

Matrix4d

四次元行列。
三次元同様、コンストラクタとcloneメソッドのみ。

Matrix4d.java
package org.MyNet2;

import java.util.ArrayList;
import java.util.Random;

/**
 * Class for four dimentional matrix.
 */
public class Matrix4d {
    /** Value of this matrix. */
    public ArrayList<Matrix3d> matrix;
    /** Shape of this matrix. */
    public int[] shape;

    /**
     * Exit this program.
     * @param msg Message for printing.
     */
    protected void exit(String msg){
        System.out.println(msg);
        System.exit(-1);
    }

    /**
     * Constructor for this class.
     * @param shape Shape of this matrix.
     */
    public Matrix4d(int[] shape){
        this.shape = shape.clone();
        this.matrix = new ArrayList<Matrix3d>();

        int[] shapeElement = {this.shape[1], this.shape[2], this.shape[3]};
        for (int i = 0; i < this.shape[0]; i++){
            this.matrix.add(new Matrix3d(shapeElement));
        }
    }

    /**
     * Constructor for this class.
     * @param shape Shape of this matrix.
     * @param num Number to fill.
     */
    public Matrix4d(int[] shape, double num){
        this.shape = shape.clone();
        this.matrix = new ArrayList<Matrix3d>();

        int[] shapeElement = {this.shape[1], this.shape[2], this.shape[3]};
        for (int i = 0; i < this.shape[0]; i++){
            this.matrix.add(new Matrix3d(shapeElement, num));
        }
    }

    /**
     * Constructor for this class.
     * @param shape Shape of this matrix.
     * @param rand Random instance.
     */
    public Matrix4d(int[] shape, Random rand){
        this.shape = shape.clone();
        this.matrix = new ArrayList<Matrix3d>();

        int[] shapeElement = {this.shape[1], this.shape[2], this.shape[3]};
        for (int i = 0; i < this.shape[0]; i++){
            this.matrix.add(new Matrix3d(shapeElement, rand));
        }
    }

    /**
     * Constructor for this class.
     * @param shape Shape of this matrix.
     * @param rand Random instance.
     * @param min Number of min for range.
     * @param max Number of max for range.
     */
    public Matrix4d(int[] shape, Random rand, double min, double max){
        this.shape = shape.clone();
        this.matrix = new ArrayList<Matrix3d>();

        int[] shapeElement = {this.shape[1], this.shape[2], this.shape[3]};
        for (int i = 0; i < this.shape[0]; i++){
            this.matrix.add(new Matrix3d(shapeElement, rand, min, max));
        }
    }

    @Override
    public Matrix4d clone(){
        Matrix4d rtn = new Matrix4d(this.shape);

        for (int i = 0; i < this.shape[0]; i++){
            rtn.matrix.set(i, this.matrix.get(i).clone());
        }

        return rtn;
    }

    @Override
    public String toString(){
        String rtn = "[";
        int j, k;

        for (int i = 0; i < this.shape[0]; i++){
            if (i == 0){
                rtn += "[";
            }else{
                rtn += "\n [";
            }
            j = 0;
            for (Matrix matrix: this.matrix.get(i).matrix){
                if (j == 0){
                    rtn += "[";
                }else{
                    rtn += "\n  [";
                }
                j++;
                k = 0;
                for (double[] ele: matrix.matrix){
                    if (k == 0){
                        rtn += "[";
                    }else{
                        rtn += "\n   [";
                    }
                    k++;
                    for (double num: ele){
                        if (num < 0){
                            rtn += String.format("%.4f ", num);
                        }else{
                            rtn += String.format(" %.4f ", num);
                        }
                    }
                    rtn += "]";
                }
                rtn += "]";
            }
            rtn += "]";
        }

        rtn += "]\n";

        return rtn;
    }
}

Conv

いよいよ畳み込み層を作ります。
コンストラクタでは入力チャネル数と出力チャネル数(重みのカーネル数)、重み行列のサイズのみを指定します。
入力チャネル数を指定しないコンストラクタについては、作る予定もありますがとりあえずなしで行きます。

また、パディングやストライドについてですが、どうせオセロでは使用しないのでここでは実装しないことにします。これはプーリング層も同様です。

Conv.java
package org.MyNet2.layer;

import java.util.Random;
import org.MyNet2.*;
import org.MyNet2.actFunc.*;

/**
 * Class for convolution layer.
 */
public class Conv extends Layer {
    /** The list of weight for this layer */
    public Matrix4d w;
    /** Number of chanel. */
    public int channelNum;
    /** Number of kernel. */
    public int kernelNum;
    /** Row of weight matrix. */
    public int wRow;
    /** Column of weight matrix. */
    public int wCol;

    /**
     * Constructor for this class.
     * @param channelNum Number of channel.
     * @param kernelNum Number of kernel.
     * @param wRow Row of weight matrix.
     * @param wCol Column of input weight matrix.
     * @param afType Type of activation fucntion for this layer.
     */
    public Conv(int channelNum, int kernelNum, int wRow, int wCol, AFType afType){
        this.setup(channelNum, kernelNum, wRow, wCol, 0, afType);
    }

    /**
     * Constructor for this class.
     * @param channelNum Number of channel.
     * @param kernelNum Number of kernel.
     * @param wRow Row of input weight matrix.
     * @param wCol Column of input weight matrix.
     * @param seed Number of seed for random class.
     * @param afType Type of activation fucntion for this layer.
     */
    public Conv(int channelNum, int kernelNum, int wRow, int wCol, long seed, AFType afType){
        this.setup(channelNum, kernelNum, wRow, wCol, seed, afType);
    }

    /**
     * Construct instead of constructor.
     * @param channelNum Number of channel.
     * @param kernelNum Number of kernel.
     * @param wRow Row of input weight matrix.
     * @param wCol Column of input weight matrix.
     * @param seed Number of seed for random class.
     * @param afType Type of activation fucntion for this layer.
     */
    protected void setup(int channelNum, int kernelNum, int wRow, int wCol, long seed, AFType afType){
        this.channelNum = channelNum;
        this.kernelNum = kernelNum;
        this.wRow = wRow;
        this.wCol = wCol;

        this.w = new Matrix4d(new int[]{kernelNum, channelNum, wRow, wCol}, new Random(seed));

        this.afType = afType;
        switch (afType){
        case SIGMOID:
            this.actFunc = new Sigmoid();
            break;
        case RELU:
            this.actFunc = new ReLU();
            break;
        case TANH:
            this.actFunc = new Tanh();
            break;
        case LINEAR:
            this.actFunc = new Linear();
            break;
        default:
            System.out.println("ERROR: The specified activation function is wrong");
            System.exit(-1);
        }
        this.actFuncName = this.actFunc.toString();
    }

    /**
     * Doing forward propagation.
     * @param in input matrix.
     * @return Matrix4d instance of output.
     */
    @Override
    public Matrix4d forward(Matrix4d in){
        int batchSize = in.shape[0];
        int rtnRow = in.shape[2] - this.w.shape[2] + 1;
        int rtnCol = in.shape[3] - this.w.shape[3] + 1;

        Matrix4d rtn = new Matrix4d(
            new int[]{
                batchSize,
                this.kernelNum,
                rtnRow,
                rtnCol
            }
        );
        for (int b = 0; b < batchSize; b++){
            for (int k = 0; k < this.kernelNum; k++){
                for (int i = 0; i < rtnRow; i++){
                    for (int j = 0; j < rtnCol; j++){
                        for (int c = 0; c < this.channelNum; c++){
                            for (int p = 0; p < this.wRow; p++){
                                for (int q = 0; q < this.wCol; q++){
                                    rtn.matrix.get(b).matrix.get(k).matrix[i][j] += 
                                        this.w.matrix.get(k).matrix.get(c).matrix[p][q] * in.matrix.get(b).matrix.get(c).matrix[i+p][j+q];
                                }
                            }
                        }

                        rtn.matrix.get(b).matrix.set(k, this.actFunc.calc(rtn.matrix.get(b).matrix.get(k)));
                    }
                }
            }
        }

        return rtn;
    }

    @Override
    public String toString(){
        String str = String.format(
            "----------------------------------------------------------------\n"
            + "Convolution\n"
            + "channels: %d, kernels: %d, conv size: %dx%d",
            this.channelNum, this.kernelNum, this.wRow, this.wCol
        );

        return str;
    }
}

テスト

2x3x6x6の入力行列を、3x42x2の重み行列で畳み込んでみます。
入力行列と重みは、それぞれ以下のプログラム内で初期化しています。実際には重みは乱数で初期化します。

ConvTest.java
import javax.management.relation.RelationSupport;

import org.MyNet2.*;
import org.MyNet2.layer.*;
import org.MyNet2.actFunc.*;

public class ConvTest {
    public static void main(String[] str){
        // 入力行列作成
        Matrix4d in = new Matrix4d(new int[]{2, 3, 6, 6});
        Matrix m = new Matrix(6, 6);

        for (int i = 0; i < m.row; i++){
            for (int j = 0; j < m.col; j++){
                m.matrix[i][j] = i * 0.1 - j * 0.2;
            }
        }

        for (int i = 0; i < in.shape[0]; i++){
            for (int j = 0; j < in.shape[1]; j++){
                in.matrix.get(i).matrix.set(j, m.add(i * 0.1 - j * 0.2));
            }
        }

        // 入力行列表示
        System.out.println(in);
        System.out.println();

        // 畳み込み層作成
        Conv conv = new Conv(3, 4, 2, 2, AFType.RELU);
        // 重み初期化と重み表示(実際に使用する際はしない)
        for (int i = 0; i < conv.kernelNum; i++){
            for (int j = 0; j < conv.channelNum; j++){
                conv.w.matrix.get(i).matrix.set(j, new Matrix(2, 2, i * 0.1 - j * 0.2));
                System.out.println(conv.w.matrix.get(i).matrix.get(j));
            }
            System.out.println();
        }
        // 畳み込み層順伝播
        System.out.println(conv.forward(in));
    }
}

実行結果はこちら。
入力行列、重み行列、出力行列の順に出力されます。

[[[[ 0.0000 -0.2000 -0.4000 -0.6000 -0.8000 -1.0000 ]
   [ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]
   [ 0.2000  0.0000 -0.2000 -0.4000 -0.6000 -0.8000 ]
   [ 0.3000  0.1000 -0.1000 -0.3000 -0.5000 -0.7000 ]
   [ 0.4000  0.2000  0.0000 -0.2000 -0.4000 -0.6000 ]
   [ 0.5000  0.3000  0.1000 -0.1000 -0.3000 -0.5000 ]]
  [[-0.2000 -0.4000 -0.6000 -0.8000 -1.0000 -1.2000 ]
   [-0.1000 -0.3000 -0.5000 -0.7000 -0.9000 -1.1000 ]
   [ 0.0000 -0.2000 -0.4000 -0.6000 -0.8000 -1.0000 ]
   [ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]
   [ 0.2000  0.0000 -0.2000 -0.4000 -0.6000 -0.8000 ]
   [ 0.3000  0.1000 -0.1000 -0.3000 -0.5000 -0.7000 ]]
  [[-0.4000 -0.6000 -0.8000 -1.0000 -1.2000 -1.4000 ]
   [-0.3000 -0.5000 -0.7000 -0.9000 -1.1000 -1.3000 ]
   [-0.2000 -0.4000 -0.6000 -0.8000 -1.0000 -1.2000 ]
   [-0.1000 -0.3000 -0.5000 -0.7000 -0.9000 -1.1000 ]
   [ 0.0000 -0.2000 -0.4000 -0.6000 -0.8000 -1.0000 ]
   [ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]]]
 [[[ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]
   [ 0.2000  0.0000 -0.2000 -0.4000 -0.6000 -0.8000 ]
   [ 0.3000  0.1000 -0.1000 -0.3000 -0.5000 -0.7000 ]
   [ 0.4000  0.2000  0.0000 -0.2000 -0.4000 -0.6000 ]
   [ 0.5000  0.3000  0.1000 -0.1000 -0.3000 -0.5000 ]
   [ 0.6000  0.4000  0.2000 -0.0000 -0.2000 -0.4000 ]]
  [[-0.1000 -0.3000 -0.5000 -0.7000 -0.9000 -1.1000 ]
   [ 0.0000 -0.2000 -0.4000 -0.6000 -0.8000 -1.0000 ]
   [ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]
   [ 0.2000  0.0000 -0.2000 -0.4000 -0.6000 -0.8000 ]
   [ 0.3000  0.1000 -0.1000 -0.3000 -0.5000 -0.7000 ]
   [ 0.4000  0.2000 -0.0000 -0.2000 -0.4000 -0.6000 ]]
  [[-0.3000 -0.5000 -0.7000 -0.9000 -1.1000 -1.3000 ]
   [-0.2000 -0.4000 -0.6000 -0.8000 -1.0000 -1.2000 ]
   [-0.1000 -0.3000 -0.5000 -0.7000 -0.9000 -1.1000 ]
   [ 0.0000 -0.2000 -0.4000 -0.6000 -0.8000 -1.0000 ]
   [ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]
   [ 0.2000 -0.0000 -0.2000 -0.4000 -0.6000 -0.8000 ]]]]


[[ 0.0000  0.0000 ]
 [ 0.0000  0.0000 ]]

[[-0.2000 -0.2000 ]
 [-0.2000 -0.2000 ]]

[[-0.4000 -0.4000 ]
 [-0.4000 -0.4000 ]]


[[ 0.1000  0.1000 ]
 [ 0.1000  0.1000 ]]

[[-0.1000 -0.1000 ]
 [-0.1000 -0.1000 ]]

[[-0.3000 -0.3000 ]
 [-0.3000 -0.3000 ]]


[[ 0.2000  0.2000 ]
 [ 0.2000  0.2000 ]]

[[ 0.0000  0.0000 ]
 [ 0.0000  0.0000 ]]

[[-0.2000 -0.2000 ]
 [-0.2000 -0.2000 ]]


[[ 0.3000  0.3000 ]
 [ 0.3000  0.3000 ]]

[[ 0.1000  0.1000 ]
 [ 0.1000  0.1000 ]]

[[-0.1000 -0.1000 ]
 [-0.1000 -0.1000 ]]


[[[[ 0.9200  1.4000  1.8800  2.3600  2.8400 ]
   [ 0.6800  1.1600  1.6400  2.1200  2.6000 ]
   [ 0.4400  0.9200  1.4000  1.8800  2.3600 ]
   [ 0.2000  0.6800  1.1600  1.6400  2.1200 ]
   [ 0.0000  0.4400  0.9200  1.4000  1.8800 ]]
  [[ 0.6200  0.8600  1.1000  1.3400  1.5800 ]
   [ 0.5000  0.7400  0.9800  1.2200  1.4600 ]
   [ 0.3800  0.6200  0.8600  1.1000  1.3400 ]
   [ 0.2600  0.5000  0.7400  0.9800  1.2200 ]
   [ 0.1400  0.3800  0.6200  0.8600  1.1000 ]]
  [[ 0.3200  0.3200  0.3200  0.3200  0.3200 ]
   [ 0.3200  0.3200  0.3200  0.3200  0.3200 ]
   [ 0.3200  0.3200  0.3200  0.3200  0.3200 ]
   [ 0.3200  0.3200  0.3200  0.3200  0.3200 ]
   [ 0.3200  0.3200  0.3200  0.3200  0.3200 ]]
  [[ 0.0200  0.0000  0.0000  0.0000  0.0000 ]
   [ 0.1400  0.0000  0.0000  0.0000  0.0000 ]
   [ 0.2600  0.0200  0.0000  0.0000  0.0000 ]
   [ 0.3800  0.1400  0.0000  0.0000  0.0000 ]
   [ 0.5000  0.2600  0.0200  0.0000  0.0000 ]]]
 [[[ 0.6800  1.1600  1.6400  2.1200  2.6000 ]
   [ 0.4400  0.9200  1.4000  1.8800  2.3600 ]
   [ 0.2000  0.6800  1.1600  1.6400  2.1200 ]
   [ 0.0000  0.4400  0.9200  1.4000  1.8800 ]
   [ 0.0000  0.2000  0.6800  1.1600  1.6400 ]]
  [[ 0.5000  0.7400  0.9800  1.2200  1.4600 ]
   [ 0.3800  0.6200  0.8600  1.1000  1.3400 ]
   [ 0.2600  0.5000  0.7400  0.9800  1.2200 ]
   [ 0.1400  0.3800  0.6200  0.8600  1.1000 ]
   [ 0.0200  0.2600  0.5000  0.7400  0.9800 ]]
  [[ 0.3200  0.3200  0.3200  0.3200  0.3200 ]
   [ 0.3200  0.3200  0.3200  0.3200  0.3200 ]
   [ 0.3200  0.3200  0.3200  0.3200  0.3200 ]
   [ 0.3200  0.3200  0.3200  0.3200  0.3200 ]
   [ 0.3200  0.3200  0.3200  0.3200  0.3200 ]]
  [[ 0.1400  0.0000  0.0000  0.0000  0.0000 ]
   [ 0.2600  0.0200  0.0000  0.0000  0.0000 ]
   [ 0.3800  0.1400  0.0000  0.0000  0.0000 ]
   [ 0.5000  0.2600  0.0200  0.0000  0.0000 ]
   [ 0.6200  0.3800  0.1400  0.0000  0.0000 ]]]]

Excelで同じように畳み込んだところ、以下のようになりました。

image.png

正しく畳み込めていることが分かります。

Pooling

プーリング層。
平均プーリングとマックスプーリングがありますが、これらの親クラスとなるPoolingクラスを作成します。
四次元データを全結合層へ渡すためのflattenメソッドを備えています。

Poolig.java
package org.MyNet2.layer;

import org.MyNet2.*;

/**
 * Class for pooling layer.
 */
public class Pooling extends Layer {
    /** Pooling matrix size. */
    public int poolSize;
    /** Number of channel. */
    public int channelNum;

    public Pooling(){
        ;
    }

    /**
     * Flatten 4 dimentional matrix to 2 dimentional matrix.
     * @param in 4 dimentional matrix.
     * @return 2 dimentional matrix.
     */
    public Matrix flatten(Matrix4d in){
        Matrix rtn = new Matrix(in.shape[0], in.shape[1]*in.shape[2]*in.shape[3]);

        for (int i = 0; i < rtn.row; i++){
            for (int j = 0; j < in.shape[1]; j++){
                for (int k = 0; k < in.shape[2]; k++){
                    for (int l = 0; l < in.shape[3]; l++){
                        rtn.matrix[i][j*in.shape[1]+k*in.shape[2]+l] = in.matrix.get(i).matrix.get(j).matrix[k][l];
                    }
                }
            }
        }

        return rtn;
    }

    /**
     * Doing forward propagation.
     * @param in input matrix.
     * @return Matrix instance of output.
     */
    @Override
    public Matrix4d forward(Matrix4d in){
        return in.clone();
    }

    /**
     * Doing back propagation.
     */
    @Override
    public void back(){
        ;
    }
}

MaxPooling

マックスプーリングクラス。
チャネル数とプーリング行列の大きささえ指定すれば作成できます。

MaxPoolig.java
package org.MyNet2.layer;

import org.MyNet2.*;

/**
 * Class for max pooling layer.
 */
public class MaxPooling extends Pooling {
    /**
     * Constructor fot this challs.
     * @param channelNum Number of channel.
     * @param poolSize Size of pooling matrix.
     */
    public MaxPooling(int channelNum, int poolSize){
        this.channelNum = channelNum;
        this.poolSize = poolSize;
    }

    /**
     * Doing forward propagation.
     * @param in input matrix.
     * @return Matrix instance of output.
     */
    @Override
    public Matrix4d forward(Matrix4d in){
        int batchSize = in.shape[0];
        int rtnRow = in.shape[2] / this.poolSize;
        int rtnCol = in.shape[3] / this.poolSize;
        double num, max;

        Matrix4d rtn = new Matrix4d(
            new int[]{
                batchSize,
                this.channelNum,
                rtnRow,
                rtnCol
            }
        );
        for (int b = 0; b < batchSize; b++){
            for (int k = 0; k < this.channelNum; k++){
                for (int i = 0; i < rtnRow; i++){
                    for (int j = 0; j < rtnCol; j++){
                        max = -100.0;
                        for (int p = 0; p < this.poolSize; p++){
                            for (int q = 0; q < this.poolSize; q++){
                                num = in.matrix.get(b).matrix.get(k).matrix[this.poolSize*i+p][this.poolSize*j+q];

                                if (num > max){
                                    max = num;
                                }
                            }
                        }

                        rtn.matrix.get(b).matrix.get(k).matrix[i][j] = max;
                    }
                }
            }
        }

        return rtn;
    }

    @Override
    public String toString(){
        String str = String.format(
            "----------------------------------------------------------------\n"
            + "MaxPooling\n"
            + "channels: %d, pool size: %d", this.channelNum, this.poolSize
        );

        return str;
    }
}
テスト

2x3x6x6の四次元行列を2x2でマックスプーリングします。

MaxPoolingTest.java
import org.MyNet2.*;
import org.MyNet2.layer.*;

public class MaxPoolingTest {
    public static void main(String[] str){
        // 入力行列作成
        Matrix4d in = new Matrix4d(new int[]{2, 3, 6, 6});
        Matrix m = new Matrix(6, 6);

        for (int i = 0; i < m.row; i++){
            for (int j = 0; j < m.col; j++){
                m.matrix[i][j] = i * 0.1 - j * 0.2;
            }
        }

        for (int i = 0; i < in.shape[0]; i++){
            for (int j = 0; j < in.shape[1]; j++){
                in.matrix.get(i).matrix.set(j, m.add(i * 0.1 - j * 0.2));
            }
        }

        // 入力行列表示
        System.out.println(in);
        System.out.println();

        // プーリング層作成
        MaxPooling pool = new MaxPooling(3, 2);
        // プーリング層順伝播
        System.out.println(pool.forward(in));
    }
}

実行結果はこちら。

[[[[ 0.0000 -0.2000 -0.4000 -0.6000 -0.8000 -1.0000 ]
   [ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]
   [ 0.2000  0.0000 -0.2000 -0.4000 -0.6000 -0.8000 ]
   [ 0.3000  0.1000 -0.1000 -0.3000 -0.5000 -0.7000 ]
   [ 0.4000  0.2000  0.0000 -0.2000 -0.4000 -0.6000 ]
   [ 0.5000  0.3000  0.1000 -0.1000 -0.3000 -0.5000 ]]
  [[-0.2000 -0.4000 -0.6000 -0.8000 -1.0000 -1.2000 ]
   [-0.1000 -0.3000 -0.5000 -0.7000 -0.9000 -1.1000 ]
   [ 0.0000 -0.2000 -0.4000 -0.6000 -0.8000 -1.0000 ]
   [ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]
   [ 0.2000  0.0000 -0.2000 -0.4000 -0.6000 -0.8000 ]
   [ 0.3000  0.1000 -0.1000 -0.3000 -0.5000 -0.7000 ]]
  [[-0.4000 -0.6000 -0.8000 -1.0000 -1.2000 -1.4000 ]
   [-0.3000 -0.5000 -0.7000 -0.9000 -1.1000 -1.3000 ]
   [-0.2000 -0.4000 -0.6000 -0.8000 -1.0000 -1.2000 ]
   [-0.1000 -0.3000 -0.5000 -0.7000 -0.9000 -1.1000 ]
   [ 0.0000 -0.2000 -0.4000 -0.6000 -0.8000 -1.0000 ]
   [ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]]]
 [[[ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]
   [ 0.2000  0.0000 -0.2000 -0.4000 -0.6000 -0.8000 ]
   [ 0.3000  0.1000 -0.1000 -0.3000 -0.5000 -0.7000 ]
   [ 0.4000  0.2000  0.0000 -0.2000 -0.4000 -0.6000 ]
   [ 0.5000  0.3000  0.1000 -0.1000 -0.3000 -0.5000 ]
   [ 0.6000  0.4000  0.2000 -0.0000 -0.2000 -0.4000 ]]
  [[-0.1000 -0.3000 -0.5000 -0.7000 -0.9000 -1.1000 ]
   [ 0.0000 -0.2000 -0.4000 -0.6000 -0.8000 -1.0000 ]
   [ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]
   [ 0.2000  0.0000 -0.2000 -0.4000 -0.6000 -0.8000 ]
   [ 0.3000  0.1000 -0.1000 -0.3000 -0.5000 -0.7000 ]
   [ 0.4000  0.2000 -0.0000 -0.2000 -0.4000 -0.6000 ]]
  [[-0.3000 -0.5000 -0.7000 -0.9000 -1.1000 -1.3000 ]
   [-0.2000 -0.4000 -0.6000 -0.8000 -1.0000 -1.2000 ]
   [-0.1000 -0.3000 -0.5000 -0.7000 -0.9000 -1.1000 ]
   [ 0.0000 -0.2000 -0.4000 -0.6000 -0.8000 -1.0000 ]
   [ 0.1000 -0.1000 -0.3000 -0.5000 -0.7000 -0.9000 ]
   [ 0.2000 -0.0000 -0.2000 -0.4000 -0.6000 -0.8000 ]]]]


[[[[ 0.1000 -0.3000 -0.7000 ]
   [ 0.3000 -0.1000 -0.5000 ]
   [ 0.5000  0.1000 -0.3000 ]]
  [[-0.1000 -0.5000 -0.9000 ]
   [ 0.1000 -0.3000 -0.7000 ]
   [ 0.3000 -0.1000 -0.5000 ]]
  [[-0.3000 -0.7000 -1.1000 ]
   [-0.1000 -0.5000 -0.9000 ]
   [ 0.1000 -0.3000 -0.7000 ]]]
 [[[ 0.2000 -0.2000 -0.6000 ]
   [ 0.4000  0.0000 -0.4000 ]
   [ 0.6000  0.2000 -0.2000 ]]
  [[ 0.0000 -0.4000 -0.8000 ]
   [ 0.2000 -0.2000 -0.6000 ]
   [ 0.4000 -0.0000 -0.4000 ]]
  [[-0.2000 -0.6000 -1.0000 ]
   [ 0.0000 -0.4000 -0.8000 ]
   [ 0.2000 -0.2000 -0.6000 ]]]]

Excelで同じようにマックスプーリングしたところ、以下の結果になりました。

image.png

正しくプーリング出来ていることが分かります。

フルバージョン

次回は

層をまとめて扱うネットワーククラスを作成します。

次回

参考文献

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0