今回の目標
ネットワーククラスを作成する。
ここから本編
現在のディレクトリ構成
MyNet2
├── Matrix.java
├── actFunc
│ └── 活性化関数クラス
├── layer
│ └── 層クラス
├── network
│ └── ネットワーククラス
└── tests
└── テストクラス
修正点
全ての層で共通するデータ型を使用するため、ConvとMaxPoolingのforwardメソッドを変更しました。
また、二次元と四次元の相互変換のため、MatrixクラスにtoMatrix4dメソッドを、Matrix4dクラスにflattenメソッドをそれぞれ追加しました。
/**
 * Doing forward propagation.
 * Computes a valid (no padding, stride 1) convolution for every sample,
 * then applies this layer's activation function.
 * @param in input matrix. Each row is one sample, flattened channel-major
 *           (channelNum * inRow * inCol columns) — TODO confirm layout with Matrix4d.flatten.
 * @return Matrix instance of output (in.row x kernelNum*outRow*outCol).
 */
@Override
public Matrix forward(Matrix in){
    // Precomputed strides into the flattened output, weight and input rows.
    int kMult = this.outRow * this.outCol;
    int cWMult = this.wRow * this.wCol;
    int cInMult = this.inRow * this.inCol;
    Matrix rtn = new Matrix(in.row, this.kernelNum * this.outRow * this.outCol);
    for (int b = 0; b < in.row; b++){
        for (int k = 0; k < this.kernelNum; k++){
            for (int i = 0; i < this.outRow; i++){
                for (int j = 0; j < this.outCol; j++){
                    // Accumulate over all channels and the kernel window.
                    for (int c = 0; c < this.channelNum; c++){
                        for (int p = 0; p < this.wRow; p++){
                            for (int q = 0; q < this.wCol; q++){
                                // Debug printf removed: it ran once per multiply-accumulate
                                // and flooded stdout while slowing the forward pass.
                                rtn.matrix[b][k*kMult + i*this.outCol + j] +=
                                    this.w.matrix[k][c*cWMult + p*this.wCol + q]
                                    * in.matrix[b][c*cInMult + (i+p)*this.inCol + (j+q)];
                            }
                        }
                    }
                }
            }
        }
    }
    return this.actFunc.calc(rtn);
}
/**
 * Doing forward propagation.
 * Applies max pooling with a poolSize x poolSize window and stride poolSize
 * to every channel of every sample.
 * @param in input matrix. Each row is one sample, flattened channel-major
 *           (channelNum * inRow * inCol columns).
 * @return Matrix instance of output (in.row x channelNum*outRow*outCol).
 */
@Override
public Matrix forward(Matrix in){
    int itr;
    double num, max;
    int kInMult = this.inRow * this.inCol;
    int kOutMult = this.outRow * this.outCol;
    int poolRowBase, poolColBase;
    Matrix rtn = new Matrix(in.row, this.channelNum * this.outRow * this.outCol);
    for (int b = 0; b < in.row; b++){
        for (int k = 0; k < this.channelNum; k++){
            for (int i = 0; i < this.outRow; i++){
                for (int j = 0; j < this.outCol; j++){
                    // Start from -infinity, not an arbitrary sentinel like -100.0,
                    // so pooling is correct even when every input is below -100.
                    max = Double.NEGATIVE_INFINITY;
                    poolRowBase = this.poolSize * i;
                    poolColBase = this.poolSize * j;
                    for (int p = 0; p < this.poolSize; p++){
                        for (int q = 0; q < this.poolSize; q++){
                            itr = k*kInMult + (poolRowBase+p)*this.inCol + (poolColBase+q);
                            num = in.matrix[b][itr];
                            if (num > max){
                                max = num;
                            }
                        }
                    }
                    rtn.matrix[b][k*kOutMult + i*this.outCol + j] = max;
                }
            }
        }
    }
    return rtn;
}
/**
 * Make a 4 dimensional matrix from this 2 dimensional matrix.
 * @param shape Shape of the 4 dimensional matrix ({batch, channel, row, col}).
 * @return 4 dimensional matrix.
 */
public Matrix4d toMatrix4d(int[] shape){
    if (shape.length != 4){
        this.exit("shape is wrong.");
    }else if(shape[0] != this.row){
        this.exit("row number is wrong.");
    }else if(shape[1] * shape[2] * shape[3] != this.col){
        // Element counts must match, otherwise the copy below indexes out of range.
        this.exit("column number is wrong.");
    }
    int jMult = shape[2] * shape[3];
    Matrix4d rtn = new Matrix4d(shape);
    for (int i = 0; i < shape[0]; i++){
        for (int j = 0; j < shape[1]; j++){
            for (int k = 0; k < shape[2]; k++){
                for (int l = 0; l < shape[3]; l++){
                    rtn.matrix.get(i).matrix.get(j).matrix[k][l] = this.matrix[i][j*jMult + k*shape[3] + l];
                }
            }
        }
    }
    return rtn;
}
/**
 * Make a 4 dimensional matrix from this 2 dimensional matrix.
 * Convenience overload of {@link #toMatrix4d(int[])}.
 * @param shape0 Batch size (must equal this.row).
 * @param shape1 Number of channels.
 * @param shape2 Number of rows per channel.
 * @param shape3 Number of columns per channel.
 * @return 4 dimensional matrix.
 */
public Matrix4d toMatrix4d(int shape0, int shape1, int shape2, int shape3){
    // Delegate so the validation and copy logic live in exactly one place.
    return this.toMatrix4d(new int[]{shape0, shape1, shape2, shape3});
}
/**
 * Flatten this 4 dimensional matrix into a 2 dimensional matrix.
 * Each batch entry becomes one row, channels laid out back to back
 * in row-major order.
 * @return 2 dimensional matrix (shape[0] x shape[1]*shape[2]*shape[3]).
 */
public Matrix flatten(){
    int batch = this.shape[0];
    int channels = this.shape[1];
    int rows = this.shape[2];
    int cols = this.shape[3];
    Matrix out = new Matrix(batch, channels * rows * cols);
    for (int b = 0; b < batch; b++){
        // Walk the flattened row left to right instead of recomputing offsets.
        int idx = 0;
        for (int c = 0; c < channels; c++){
            Matrix channel = this.matrix.get(b).matrix.get(c);
            for (int r = 0; r < rows; r++){
                for (int s = 0; s < cols; s++){
                    out.matrix[b][idx++] = channel.matrix[r][s];
                }
            }
        }
    }
    return out;
}
また、それに合わせてテストクラスの内容も若干ではありますが変更しました。本題ではないので興味のある方だけ見てください。
テストクラスの変更
import javax.management.relation.RelationSupport;
import org.MyNet2.*;
import org.MyNet2.layer.*;
import org.MyNet2.actFunc.*;
/** Manual check for Conv.forward with deterministic inputs and weights. */
public class ConvTest {
    public static void main(String[] str){
        // Build a 2x3x6x6 input: every channel is a shifted copy of one base matrix.
        Matrix4d in = new Matrix4d(new int[]{2, 3, 6, 6});
        Matrix base = new Matrix(6, 6);
        for (int r = 0; r < base.row; r++){
            for (int c = 0; c < base.col; c++){
                base.matrix[r][c] = r * 0.1 - c * 0.2;
            }
        }
        for (int b = 0; b < in.shape[0]; b++){
            for (int ch = 0; ch < in.shape[1]; ch++){
                in.matrix.get(b).matrix.set(ch, base.add(b * 0.1 - ch * 0.2));
            }
        }
        System.out.println(in);
        System.out.println();
        // 3 input channels, 4 kernels of size 2x2, ReLU activation.
        Conv conv = new Conv(3, 4, new int[]{6, 6}, new int[]{2, 2}, AFType.RELU);
        // Overwrite the random weights with deterministic constant kernels.
        Matrix4d w = new Matrix4d(new int[]{4, 3, 2, 2});
        for (int k = 0; k < conv.kernelNum; k++){
            for (int ch = 0; ch < conv.channelNum; ch++){
                w.matrix.get(k).matrix.set(ch, new Matrix(2, 2, k * 0.1 - ch * 0.2));
                System.out.println(w.matrix.get(k).matrix.get(ch));
            }
            System.out.println();
        }
        conv.w = w.flatten();
        Matrix out = conv.forward(in.flatten());
        System.out.println(out.toMatrix4d(2, 4, 5, 5));
    }
}
import org.MyNet2.*;
import org.MyNet2.layer.*;
/** Manual check for MaxPooling.forward with deterministic inputs. */
public class MaxPoolingTest {
    public static void main(String[] str){
        // Build a 2x3x6x6 input: every channel is a shifted copy of one base matrix.
        Matrix4d in = new Matrix4d(new int[]{2, 3, 6, 6});
        Matrix base = new Matrix(6, 6);
        for (int r = 0; r < base.row; r++){
            for (int c = 0; c < base.col; c++){
                base.matrix[r][c] = r * 0.1 - c * 0.2;
            }
        }
        for (int b = 0; b < in.shape[0]; b++){
            for (int ch = 0; ch < in.shape[1]; ch++){
                in.matrix.get(b).matrix.set(ch, base.add(b * 0.1 - ch * 0.2));
            }
        }
        System.out.println(in);
        System.out.println();
        // 3 channels of 6x6 input, pooled with a 2x2 window.
        MaxPooling pool = new MaxPooling(3, new int[]{6, 6}, 2);
        Matrix out = pool.forward(in.flatten());
        System.out.println(out.toMatrix4d(2, 3, 3, 3));
    }
}
また、各層のtoStringメソッドを変更しました。
toString
@Override
public String toString(){
    // inNum-1 is displayed as the input count; presumably inNum includes
    // a bias column — TODO confirm against the Dense weight layout.
    return String.format(
        "----------------------------------------------------------------\n"
        + "Dense\nact: %s\n"
        + "%d => %d",
        this.actFuncName, this.inNum-1, this.nodesNum
    );
}
@Override
public String toString(){
    // Summary line: in channels,rows,cols => (pool window) => out channels,rows,cols.
    return String.format(
        "----------------------------------------------------------------\n"
        + "MaxPooling\nact: null\n"
        + "%d, %d, %d => (%d, %d) => %d, %d, %d",
        this.channelNum, this.inRow, this.inCol,
        this.poolSize, this.poolSize,
        this.channelNum, this.outRow, this.outCol
    );
}
@Override
public String toString(){
    // Summary line: in channels,rows,cols => (kernel shape) => out kernels,rows,cols.
    return String.format(
        "----------------------------------------------------------------\n"
        + "Convolution\nact: %s\n"
        + "%d, %d, %d => (%d, %d) => %d, %d, %d",
        this.actFuncName,
        this.channelNum, this.inRow, this.inCol,
        this.wRow, this.wCol,
        this.kernelNum, this.outRow, this.outCol
    );
}
ConvクラスとMaxPoolingクラスに新しいコンストラクタを追加したほか、若干の変更を行いました。
/**
 * Constructor for this class.
 * Only the kernel configuration is stored here; input shape and weights
 * are filled in later via setup (e.g. by the Network constructor).
 * @param kernelNum Number of kernel.
 * @param wShape Shape of weight matrix ({row, col}).
 * @param afType Type of activation function for this layer.
 */
public Conv(int kernelNum, int[] wShape, AFType afType){
    if (wShape.length != 2){
        this.exit("wShape length is wrong");
    }
    this.name = "Conv";
    this.kernelNum = kernelNum;
    this.afType = afType;
    this.wRow = wShape[0];
    this.wCol = wShape[1];
}
/**
 * Constructor for this class.
 * Weights are initialized from a Random seeded with 0.
 * @param channelNum Number of channel.
 * @param kernelNum Number of kernel.
 * @param inShape Shape of input.
 * @param wShape Shape of weight matrix.
 * @param afType Type of activation function for this layer.
 */
public Conv(int channelNum, int kernelNum, int[] inShape, int[] wShape, AFType afType){
    this.setup(channelNum, kernelNum, inShape, wShape, afType, 0);
}
/**
 * Constructor for this class.
 * @param channelNum Number of channel.
 * @param kernelNum Number of kernel.
 * @param inShape Shape of input.
 * @param wShape Shape of weight matrix.
 * @param afType Type of activation function for this layer.
 * @param seed Number of seed for random class.
 */
public Conv(int channelNum, int kernelNum, int[] inShape, int[] wShape, AFType afType, long seed){
    this.setup(channelNum, kernelNum, inShape, wShape, afType, seed);
}
/**
 * Construct instead of constructor.
 * Validates the shapes, derives the output size of a valid (no padding,
 * stride 1) convolution and initializes the weight matrix.
 * @param channelNum Number of channel.
 * @param kernelNum Number of kernel.
 * @param inShape Shape of input ({row, col}).
 * @param wShape Shape of weight matrix ({row, col}).
 * @param afType Type of activation function for this layer.
 * @param seed Number of seed for random class.
 */
public void setup(int channelNum, int kernelNum, int[] inShape, int[] wShape, AFType afType, long seed){
    if (inShape.length != 2){
        this.exit("inShape length is wrong.");
    }else if (wShape.length != 2){
        this.exit("wShape length is wrong.");
    }
    this.name = "Conv";
    this.channelNum = channelNum;
    this.kernelNum = kernelNum;
    this.inRow = inShape[0];
    this.inCol = inShape[1];
    // Valid convolution: output shrinks by (kernel size - 1) per axis.
    this.outRow = inShape[0] - wShape[0] + 1;
    this.outCol = inShape[1] - wShape[1] + 1;
    this.wRow = wShape[0];
    this.wCol = wShape[1];
    // One row per kernel; each row holds all channels' weights flattened.
    this.w = new Matrix(kernelNum, channelNum * wRow * wCol, new Random(seed));
    this.afType = afType;
    switch (afType){
        case SIGMOID:
            this.actFunc = new Sigmoid();
            break;
        case RELU:
            this.actFunc = new ReLU();
            break;
        case TANH:
            this.actFunc = new Tanh();
            break;
        case LINEAR:
            this.actFunc = new Linear();
            break;
        default:
            // Abort via exit() for consistency with the other argument checks
            // in this class (previously a raw println + System.exit).
            this.exit("The specified activation function is wrong");
    }
    this.actFuncName = this.actFunc.toString();
}
/**
 * Constructor for this class.
 * Channel count and input shape are expected to be filled in later
 * via setup (e.g. by the Network constructor).
 * @param poolSize Size of pooling matrix.
 */
public MaxPooling(int poolSize){
    this.name = "MaxPooling";
    this.poolSize = poolSize;
}
/**
 * Constructor for this class.
 * @param channelNum Number of channel.
 * @param inShape Shape of input.
 * @param poolSize Size of pooling matrix.
 */
public MaxPooling(int channelNum, int[] inShape, int poolSize){
    this.setup(channelNum, inShape, poolSize);
}
/**
 * Construct instead of constructor.
 * @param channelNum Number of channel.
 * @param inShape Shape of input ({row, col}).
 * @param poolSize Size of pooling matrix. Must be at least 1.
 */
public void setup(int channelNum, int[] inShape, int poolSize){
    if (inShape.length != 2){
        this.exit("inShape length is wrong.");
    }else if (poolSize < 1){
        // Guard against a division by zero (or a nonsensical window) below.
        this.exit("poolSize is wrong.");
    }
    this.name = "MaxPooling";
    this.channelNum = channelNum;
    this.inRow = inShape[0];
    this.inCol = inShape[1];
    // Non-overlapping windows; remainder rows/columns at the edges are dropped.
    this.outRow = inShape[0] / poolSize;
    this.outCol = inShape[1] / poolSize;
    this.poolSize = poolSize;
}
Network
ネットワーククラスを作成します。
このクラスでは、各層のインスタンスを受け取ります。が、実際に使用する際の手間を減らすために、各層の入力情報は初期化されていない状態です。そのためコンストラクタでは、全結合層なら入力数、畳み込み・プーリング層では入力チャネル数、行列サイズをここで指定する必要があります。
かつ、全結合層のみのネットワークと畳み込み層などを含むネットワークで異なるコンストラクタを用います。
全結合層のみのコンストラクタは以下の通り。
現在の層のノード数を次の層の入力数として初期化しています。
/**
 * Constructor for a network that contains dense layers only.
 * @param seed Seed for the Random used to initialize layer weights.
 * @param inNum Number of inputs to the first layer.
 * @param layers Layers of this network, in forward order.
 */
public Network(int seed, int inNum, Layer ... layers){
    this.setup(seed, inNum, layers);
}
/**
 * Constructor for a network that contains dense layers only.
 * Weights are initialized with seed 0.
 * @param inNum Number of inputs to the first layer.
 * @param layers Layers of this network, in forward order.
 */
public Network(int inNum, Layer ... layers){
    this.setup(0, inNum, layers);
}
/**
 * Construct instead of constructor.
 * Only dense layers are accepted; the node count of each layer becomes
 * the input count of the next one.
 * @param seed Seed of random.
 * @param inNum Number of input.
 * @param layers Each layers.
 */
protected void setup(int seed, int inNum, Layer ... layers){
    this.layers = layers;
    int prevOut = inNum;
    for (Layer layer: this.layers){
        if (layer.name.equals("Dense")){
            layer.setup(prevOut, layer.nodesNum, layer.afType, seed);
            prevOut = layer.nodesNum;
        }else{
            // This overload cannot size conv/pooling layers.
            this.exit("layer error");
        }
    }
}
次に、畳み込み層を含むネットワークのコンストラクタ。
畳み込み層ならカーネル数、プーリング層なら入力チャネル数を次の層の入力チャネル数とし、それぞれの出力行列サイズを次の層の入力行列サイズとして指定しています。
プーリング層から全結合層へ移る際は、チャネル数x行列サイズで入力数を計算しています。
その層の重み行列サイズやプールサイズなどはあらかじめその層のコンストラクタで指定しています。
/**
 * Constructor for a network that may contain convolution and pooling layers.
 * Weights are initialized with seed 0.
 * @param channelNum Number of input channels.
 * @param inRow Row of input matrix.
 * @param inCol Column of input matrix.
 * @param layers Layers of this network, in forward order.
 */
public Network(int channelNum, int inRow, int inCol, Layer ... layers){
    this.setup(0, channelNum, inRow, inCol, layers);
}
/**
 * Constructor for a network that may contain convolution and pooling layers.
 * @param seed Seed for the Random used to initialize layer weights.
 * @param channelNum Number of input channels.
 * @param inRow Row of input matrix.
 * @param inCol Column of input matrix.
 * @param layers Layers of this network, in forward order.
 */
public Network(int seed, int channelNum, int inRow, int inCol, Layer ... layers){
    this.setup(seed, channelNum, inRow, inCol, layers);
}
/**
 * Construct instead of constructor.
 * Threads the output shape of each layer into the next: conv layers pass
 * on their kernel count and output size, pooling layers their channel count
 * and output size, and dense layers consume the flattened element count.
 * @param seed Seed of random.
 * @param channelNum Number of input channels.
 * @param inRow Row of input matrix.
 * @param inCol Column of input matrix.
 * @param layers Each layers.
 */
protected void setup(int seed, int channelNum, int inRow, int inCol, Layer ... layers){
    this.layers = layers;
    // Running shape of the data as it flows through the network.
    int channels = channelNum;
    int rows = inRow;
    int cols = inCol;
    int flat = 0;
    for (Layer layer: this.layers){
        if (layer.name.equals("Dense")){
            layer.setup(flat, layer.nodesNum, layer.afType, seed);
            flat = layer.nodesNum;
        }else if (layer.name.equals("Conv")){
            layer.setup(
                channels,
                layer.kernelNum,
                new int[]{rows, cols},
                new int[]{layer.wRow, layer.wCol},
                layer.afType,
                seed
            );
            channels = layer.kernelNum;
            rows = layer.outRow;
            cols = layer.outCol;
            flat = channels * rows * cols;
        }else if (layer.name.equals("MaxPooling")){
            layer.setup(
                channels,
                new int[]{rows, cols},
                layer.poolSize
            );
            channels = layer.channelNum;
            rows = layer.outRow;
            cols = layer.outCol;
            flat = channels * rows * cols;
        }else{
            this.exit("layer error");
        }
    }
}
順方向計算メソッド。
畳み込み層とプーリング層でもMatrixクラスを使っているため、全ての層で、同じ方法で順方向計算ができます。
/**
 * Doing forward propagation.
 * Feeds the input through every layer in order; the input itself is
 * left untouched because the first step works on a clone.
 * @param in Input matrix.
 * @return Output of this network.
 */
public Matrix forward(Matrix in){
    Matrix out = in.clone();
    for (int i = 0; i < this.layers.length; i++){
        out = this.layers[i].forward(out);
    }
    return out;
}
toStringメソッド。summaryメソッドはKerasに倣い、printまで行います。
/**
 * Print summary of this network.
 */
public void summary(){
    // println(Object) goes through toString(), so this prints the same text.
    System.out.println(this);
}
@Override
public String toString(){
    // Concatenate every layer's summary between a header and a closing rule.
    StringBuilder sb = new StringBuilder("Network\n");
    for (Layer layer: this.layers){
        sb.append(layer.toString()).append("\n");
    }
    sb.append("----------------------------------------------------------------\n");
    return sb.toString();
}
ロードやセーブメソッドはMyNetから変更し、層の数などを事前に設定しなくてもロードできるようにしました。
/**
 * Save this network to one file.
 * Serializes the whole Network instance; exits the process on I/O failure.
 * @param name Name of save file.
 */
public void save(String name){
    try (
        ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(name))
    ){
        out.writeObject(this);
        out.flush();
    }catch (IOException e){
        System.out.println("IOException");
        System.exit(-1);
    }
}
/**
 * Load a network from the file.
 * Exits the process on I/O failure, an unknown class, or when the user
 * refuses to continue after a version mismatch.
 * @param name Name of load file.
 * @return Loaded Network instance.
 */
public static Network load(String name){
    Network net = null;
    try (
        ObjectInputStream in = new ObjectInputStream(new FileInputStream(name))
    ){
        net = (Network)in.readObject();
        if (Network.versionHasProblem(net)){
            System.exit(-1);
        }
    }catch (IOException e){
        System.out.println("IOException");
        System.exit(-1);
    }catch (ClassNotFoundException e){
        System.out.println("ClassNotFoundException");
        System.exit(-1);
    }
    return net;
}
/**
 * Check that the information of version has problem.
 * If the two networks have different versions, ask the user whether to
 * continue; answering no means the load should be aborted.
 * @param loaded Loaded network.
 * @return true if the user declined to continue; false otherwise.
 */
protected static boolean versionHasProblem(Network loaded){
    // Matching versions: nothing to ask.
    if (Version.version.equals(loaded.version)){
        return false;
    }
    System.out.println("The versions of the two Network classes do not match.");
    System.out.printf("The version of loaded is %s\n", loaded.version);
    System.out.printf("The version of this is %s\n", Version.version);
    System.out.println("There is a risk of serious error.");
    System.out.print("Do you want to continue? [y/n] ");
    // Scanner is intentionally not closed: closing it would close System.in.
    Scanner sc = new Scanner(System.in);
    String ans = Network.nextLine(sc);
    // equalsIgnoreCase("y") accepts exactly "y"/"Y", same as the old pair of checks.
    while (!(ans.equalsIgnoreCase("y") || ans.equalsIgnoreCase("n"))){
        System.out.println("Choose 'y' or 'n'.");
        System.out.print("Do you want to continue? [y/n] ");
        ans = Network.nextLine(sc);
    }
    if (ans.equalsIgnoreCase("y")){
        return false;
    }
    System.out.println("End.");
    return true;
}
/**
 * Read next line.
 * If no line is available (e.g. stdin is closed), return "" instead of
 * letting NoSuchElementException propagate.
 * @param sc Scanner instance.
 * @return Read String instance, or "" when input is exhausted.
 */
private static String nextLine(Scanner sc){
    // The old `throws NoSuchElementException` clause was misleading:
    // the exception is caught right here and can never escape.
    try {
        return sc.nextLine();
    }catch (NoSuchElementException e){
        // Treat end-of-input as an empty answer.
        return "";
    }
}
テスト
以下のように、2チャネル4x4の画像を10枚入力できるネットワークを作成し実行しました。
import java.util.Random;
import org.MyNet2.network.*;
import org.MyNet2.layer.*;
import org.MyNet2.actFunc.*;
import org.MyNet2.*;
/** End-to-end check: build, run, save, load and re-run a small CNN. */
public class NetworkTest {
    public static void main(String[] str){
        // Ten samples of 2-channel 4x4 input, filled with seeded random values.
        Matrix4d in = new Matrix4d(new int[]{10, 2, 4, 4}, new Random(0));
        // Channel count and input matrix size come first, then the layers.
        Network net = new Network(
            2, 4, 4,
            new Conv(4, new int[]{3, 3}, AFType.RELU), // kernels, kernel shape, activation
            new MaxPooling(2),                         // pool size only
            new Dense(4, AFType.RELU),                 // node count only
            new Dense(1, AFType.RELU)
        );
        net.summary();
        System.out.println(net.forward(in.flatten()));
        net.save("NetworkTest.net");
        // Load the saved network and run the forward pass again.
        Network netLoaded = Network.load("NetworkTest.net");
        netLoaded.summary();
        System.out.println(netLoaded.forward(in.flatten()));
    }
}
実行結果はこちら。
Network
----------------------------------------------------------------
Convolution
act: ReLU
2, 4, 4 => (3, 3) => 4, 2, 2
----------------------------------------------------------------
MaxPooling
act: null
4, 2, 2 => (2, 2) => 4, 1, 1
----------------------------------------------------------------
Dense
act: ReLU
4 => 4
----------------------------------------------------------------
Dense
act: ReLU
4 => 1
----------------------------------------------------------------
[[ 2.1788 ]
[ 2.5194 ]
[ 1.9599 ]
[ 2.2095 ]
[ 1.9963 ]
[ 1.7290 ]
[ 1.9919 ]
[ 1.8231 ]
[ 1.7836 ]
[ 1.7086 ]]
Network
----------------------------------------------------------------
Convolution
act: ReLU
2, 4, 4 => (3, 3) => 4, 2, 2
----------------------------------------------------------------
MaxPooling
act: null
4, 2, 2 => (2, 2) => 4, 1, 1
----------------------------------------------------------------
Dense
act: ReLU
4 => 4
----------------------------------------------------------------
Dense
act: ReLU
4 => 1
----------------------------------------------------------------
[[ 2.1788 ]
[ 2.5194 ]
[ 1.9599 ]
[ 2.2095 ]
[ 1.9963 ]
[ 1.7290 ]
[ 1.9919 ]
[ 1.8231 ]
[ 1.7836 ]
[ 1.7086 ]]
ロードと順方向計算が出来ています。
フルバージョン
次回は
損失関数を作成します。
[次回](https://qiita.com/tt_and_tk/items/5c647f59019a03d4cf9e)