Goal for this article
Main content
Changes
I renamed the package out_function to activationFunction.
I also added the following methods to the Matrix class.
/**
 * Replace every element of this matrix with its absolute value (in place).
 */
public void abs(){
    for (int i = 0; i < this.row; i++){
        for (int j = 0; j < this.col; j++){
            this.matrix[i][j] = Math.abs(this.matrix[i][j]);
        }
    }
}

/**
 * Return the element-wise absolute value of a matrix.
 * @param in Matrix.
 * @return New Matrix holding the absolute value of each element of in.
 */
public static Matrix abs(Matrix in){
    Matrix rtn = in.clone();
    for (int i = 0; i < rtn.row; i++){
        for (int j = 0; j < rtn.col; j++){
            rtn.matrix[i][j] = Math.abs(in.matrix[i][j]);
        }
    }
    return rtn;
}

/**
 * Calculate the average of each column.
 * @return Matrix instance (1 row) holding the average of each column of this matrix.
 */
public Matrix meanCol(){
    Matrix rtn = new Matrix(new double[1][this.col]);
    double num = 0;
    for (int j = 0; j < this.col; j++){
        num = 0;
        for (int i = 0; i < this.row; i++){
            num += this.matrix[i][j];
        }
        rtn.matrix[0][j] = num / this.row;
    }
    return rtn;
}

/**
 * Calculate the average of each column.
 * @param in A Matrix instance.
 * @return Matrix instance (1 row) holding the average of each column of in.
 */
public static Matrix meanCol(Matrix in){
    Matrix rtn = new Matrix(new double[1][in.col]);
    double num = 0;
    for (int j = 0; j < in.col; j++){
        num = 0;
        for (int i = 0; i < in.row; i++){
            num += in.matrix[i][j];
        }
        rtn.matrix[0][j] = num / in.row;
    }
    return rtn;
}

/**
 * Calculate the square root of each element of this matrix.
 * @return New Matrix instance.
 */
public Matrix sqrt(){
    Matrix rtn = new Matrix(new double[this.row][this.col]);
    for (int i = 0; i < this.row; i++){
        for (int j = 0; j < this.col; j++){
            rtn.matrix[i][j] = Math.sqrt(this.matrix[i][j]);
        }
    }
    return rtn;
}

/**
 * Calculate the square root of each element of a matrix.
 * @param in Matrix instance.
 * @return New Matrix instance.
 */
public static Matrix sqrt(Matrix in){
    Matrix rtn = new Matrix(new double[in.row][in.col]);
    for (int i = 0; i < in.row; i++){
        for (int j = 0; j < in.col; j++){
            rtn.matrix[i][j] = Math.sqrt(in.matrix[i][j]);
        }
    }
    return rtn;
}
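To make it concrete what these helpers compute, here is a minimal sketch that works on plain double[][] arrays instead of the Matrix class. The class name MatrixHelpersSketch and its methods are made up for this illustration and are not part of MyNet.

```java
import java.util.Arrays;

class MatrixHelpersSketch {
    /** Column-wise mean of a rows x cols array, returned as a single row (same idea as meanCol). */
    static double[] meanCol(double[][] m) {
        double[] mean = new double[m[0].length];
        for (int j = 0; j < m[0].length; j++) {
            double sum = 0;
            for (int i = 0; i < m.length; i++) {
                sum += m[i][j];
            }
            mean[j] = sum / m.length;
        }
        return mean;
    }

    /** Element-wise absolute value written into a new array (same idea as the static abs). */
    static double[][] abs(double[][] m) {
        double[][] out = new double[m.length][m[0].length];
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[0].length; j++) {
                out[i][j] = Math.abs(m[i][j]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] m = {{-1.0, 4.0}, {3.0, -16.0}};
        System.out.println(Arrays.toString(meanCol(m)));  // [1.0, -6.0]
        System.out.println(Arrays.deepToString(abs(m)));  // [[1.0, 4.0], [3.0, 16.0]]
    }
}
```

Note that the sketch's abs returns a new array, like the static version above; the instance version of abs overwrites the matrix it is called on.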
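Current directory layout
MyNet
├── costFunction // cost function package (created this time)
├── layer // layer package (already created)
├── matrix // matrix package (already created)
├── network // network package (already created)
├── nodes // node package (already created)
│ └── activationFunction // activation function package (already created)
└── optimzer // optimizer package (empty for now)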
CF.java
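As with the activation functions, I wanted to manage the cost functions with an enum, so I made this.
For now it covers mean squared error and mean absolute error (plus RMSE), since I expect to do only regression.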
package costFunction;

/**
 * Enum for designating the cost function.
 * CF is an abbreviation of "Cost Function".
 */
public enum CF {
    MSE,
    MAE,
    RMSE
}
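One way an enum like this could be used is to pick the loss by name in a single switch. The following is only a sketch under my own assumptions: Kind is a stand-in for CF, and elementLoss is a hypothetical helper that returns a per-element loss, not anything that exists in MyNet.

```java
import java.util.function.DoubleBinaryOperator;

class CostSelectionSketch {
    /** Stand-in for the CF enum (hypothetical, reduced to two entries). */
    enum Kind { MSE, MAE }

    /** Return the per-element loss for the chosen kind, taking (y, t). */
    static DoubleBinaryOperator elementLoss(Kind kind) {
        switch (kind) {
            case MSE:
                return (y, t) -> (y - t) * (y - t);
            case MAE:
                return (y, t) -> Math.abs(y - t);
            default:
                throw new IllegalArgumentException("Unknown kind: " + kind);
        }
    }

    public static void main(String[] args) {
        System.out.println(elementLoss(Kind.MSE).applyAsDouble(3.0, 1.0)); // 4.0
        System.out.println(elementLoss(Kind.MAE).applyAsDouble(1.0, 3.0)); // 2.0
    }
}
```

Keeping the mapping in one place means the caller only needs to pass the enum value around.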
CostFunction.java
The parent class of the cost functions.
The method bodies are just placeholders.
package costFunction;

import matrix.*;

/**
 * Cost function's base class.
 * All cost functions must extend this class.
 */
public class CostFunction {
    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public CostFunction(){
        ;
    }

    /**
     * Calculate this cost function.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return Difference between y and t.
     */
    public Matrix calcurate(Matrix y, Matrix t){
        return Matrix.sub(y, t);
    }

    /**
     * Calculate this cost function's differential.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return The result of differentiating the difference between y and t.
     */
    public Matrix differential(Matrix y, Matrix t){
        return Matrix.sub(y, t);
    }
}
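The subclasses in this article only override calcurate, so they all inherit the placeholder differential that returns y - t. For reference, here is a rough sketch of what per-element gradients of MSE and MAE with respect to y look like for a single column. CostGradientSketch and its methods are hypothetical names of mine, and this uses plain arrays rather than the Matrix class.

```java
import java.util.Arrays;

class CostGradientSketch {
    /** Gradient of MSE = (1/N) * sum (y_n - t_n)^2 with respect to each y_n. */
    static double[] mseGradient(double[] y, double[] t) {
        double[] g = new double[y.length];
        for (int n = 0; n < y.length; n++) {
            g[n] = 2.0 * (y[n] - t[n]) / y.length;
        }
        return g;
    }

    /** Subgradient of MAE = (1/N) * sum |y_n - t_n| (0 is used where y_n equals t_n). */
    static double[] maeGradient(double[] y, double[] t) {
        double[] g = new double[y.length];
        for (int n = 0; n < y.length; n++) {
            g[n] = Math.signum(y[n] - t[n]) / y.length;
        }
        return g;
    }

    public static void main(String[] args) {
        double[] y = {1.0, 2.0};
        double[] t = {0.0, 4.0};
        System.out.println(Arrays.toString(mseGradient(y, t))); // [1.0, -2.0]
        System.out.println(Arrays.toString(maeGradient(y, t))); // [0.5, -0.5]
    }
}
```

For MSE the result is y - t scaled by the constant 2/N, so the placeholder has the same direction and differs only in magnitude.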
MeanAbsoluteError.java
Mean absolute error.
$$ E(w) = \frac{1}{N}\sum_{n=1}^{N}|y_n-t_n| $$
Here $t$ is the correct value, $y$ is the predicted value, $N$ is the number of data points, and $E$ is the error.
package costFunction;

import matrix.*;

/**
 * Mean absolute error (MAE) cost function.
 */
public class MeanAbsoluteError extends CostFunction {
    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public MeanAbsoluteError(){
        ;
    }

    /**
     * Calculate this cost function.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return MAE between y and t.
     */
    public Matrix calcurate(Matrix y, Matrix t){
        Matrix rtn = Matrix.abs(Matrix.sub(y, t));
        return rtn.meanCol();
    }
}
MeanSquaredError.java
Mean squared error.
$$ E(w) = \frac{1}{N}\sum_{n=1}^{N}(y_n-t_n)^2 $$
package costFunction;

import matrix.*;

/**
 * Mean squared error (MSE) cost function.
 */
public class MeanSquaredError extends CostFunction {
    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public MeanSquaredError(){
        ;
    }

    /**
     * Calculate this cost function.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return MSE between y and t.
     */
    public Matrix calcurate(Matrix y, Matrix t){
        Matrix rtn = Matrix.abs(Matrix.sub(y, t));
        // Square each element of the difference in place.
        for (int i = 0; i < rtn.row; i++){
            for (int j = 0; j < rtn.col; j++){
                rtn.matrix[i][j] *= rtn.matrix[i][j];
            }
        }
        return rtn.meanCol();
    }
}
RootMeanSquaredError.java
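Root mean squared error.
$$ E(w) = \frac{1}{N}\sqrt{\sum_{n=1}^{N}(y_n-t_n)^2} $$
Note that the more common definition of RMSE takes the mean inside the root, $\sqrt{\frac{1}{N}\sum_{n=1}^{N}(y_n-t_n)^2}$. The class here implements the formula as written above.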
package costFunction;

import matrix.*;
import java.lang.Math;

/**
 * Root mean squared error (RMSE) cost function.
 */
public class RootMeanSquaredError extends CostFunction {
    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public RootMeanSquaredError(){
        ;
    }

    /**
     * Calculate this cost function.
     * @param y Matrix of network's output.
     * @param t Matrix of actual data.
     * @return RMSE between y and t.
     */
    public Matrix calcurate(Matrix y, Matrix t){
        Matrix rtn = new Matrix(new double[1][y.col]);
        double num;
        for (int j = 0; j < y.col; j++){
            num = 0;
            // Sum the squared differences of column j.
            for (int i = 0; i < y.row; i++){
                num += Math.pow(y.matrix[i][j] - t.matrix[i][j], 2);
            }
            rtn.matrix[0][j] = Math.sqrt(num);
        }
        // Divide the root of the summed squares by the number of rows.
        return Matrix.div(rtn, y.row);
    }
}
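For comparison, here is a small sketch that puts the calculation used in this class next to the textbook RMSE, where the mean is taken before the square root. It works on one column stored in a plain array. RmseComparisonSketch and both method names are hypothetical and only for illustration.

```java
class RmseComparisonSketch {
    /** Sum of squared differences between y and t. */
    static double squaredSum(double[] y, double[] t) {
        double sum = 0;
        for (int n = 0; n < y.length; n++) {
            sum += (y[n] - t[n]) * (y[n] - t[n]);
        }
        return sum;
    }

    /** The variant used in this article: sqrt(sum of squares) / N. */
    static double rootThenDivide(double[] y, double[] t) {
        return Math.sqrt(squaredSum(y, t)) / y.length;
    }

    /** The textbook RMSE: sqrt(sum of squares / N). */
    static double standardRmse(double[] y, double[] t) {
        return Math.sqrt(squaredSum(y, t) / y.length);
    }

    public static void main(String[] args) {
        double[] y = {1.0, 2.0, 3.0, 4.0};
        double[] t = {1.0, 2.0, 3.0, 8.0};
        System.out.println(rootThenDivide(y, t)); // 1.0
        System.out.println(standardRmse(y, t));   // 2.0
    }
}
```

The two differ by a factor of the square root of N, so they rank models the same way for a fixed data size but are not on the same scale.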
test.java
I tried out the classes I made this time.
import matrix.Matrix;
import costFunction.*;

public class test {
    public static void main(String[] str){
        double[][] m = new double[5][10];
        Matrix a = new Matrix(m);
        Matrix b = new Matrix(m);
        for (int i = 0; i < a.row; i++){
            for (int j = 0; j < a.col; j++){
                a.matrix[i][j] = (double)i*1.5 + (double)j*0.5;
                b.matrix[i][j] = (double)i*1.4 + (double)j*0.6;
            }
        }
        System.out.println(a);
        System.out.println(b);
        CostFunction f = new MeanAbsoluteError();
        System.out.println(f.calcurate(a, b));
        f = new MeanSquaredError();
        System.out.println(f.calcurate(a, b));
        f = new RootMeanSquaredError();
        System.out.println(f.calcurate(a, b));
    }
}
Here is the output.
[[0.0000 0.5000 1.0000 1.5000 2.0000 2.5000 3.0000 3.5000 4.0000 4.5000 ]
[1.5000 2.0000 2.5000 3.0000 3.5000 4.0000 4.5000 5.0000 5.5000 6.0000 ]
[3.0000 3.5000 4.0000 4.5000 5.0000 5.5000 6.0000 6.5000 7.0000 7.5000 ]
[4.5000 5.0000 5.5000 6.0000 6.5000 7.0000 7.5000 8.0000 8.5000 9.0000 ]
[6.0000 6.5000 7.0000 7.5000 8.0000 8.5000 9.0000 9.5000 10.0000 10.5000 ]]
[[0.0000 0.6000 1.2000 1.8000 2.4000 3.0000 3.6000 4.2000 4.8000 5.4000 ]
[1.4000 2.0000 2.6000 3.2000 3.8000 4.4000 5.0000 5.6000 6.2000 6.8000 ]
[2.8000 3.4000 4.0000 4.6000 5.2000 5.8000 6.4000 7.0000 7.6000 8.2000 ]
[4.2000 4.8000 5.4000 6.0000 6.6000 7.2000 7.8000 8.4000 9.0000 9.6000 ]
[5.6000 6.2000 6.8000 7.4000 8.0000 8.6000 9.2000 9.8000 10.4000 11.0000 ]]
[[0.2000 0.1400 0.1200 0.1400 0.2000 0.3000 0.4000 0.5000 0.6000 0.7000 ]]
[[0.0600 0.0300 0.0200 0.0300 0.0600 0.1100 0.1800 0.2700 0.3800 0.5100 ]]
[[0.1095 0.0775 0.0632 0.0775 0.1095 0.1483 0.1897 0.2324 0.2757 0.3194 ]]
I also ran the same calculation in Excel.
The results came out the same.
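The same kind of cross-check can be done in plain Java without the Matrix class. The sketch below recomputes the three values for one column using the same fill rule as test.java. CostCheckSketch and columnCosts are hypothetical names, and the third value uses the same root-then-divide calculation as RootMeanSquaredError above.

```java
import java.util.Locale;

class CostCheckSketch {
    /** Return {MAE, MSE, sqrt(sum of squares) / rows} for column j of the test data. */
    static double[] columnCosts(int j, int rows) {
        double absSum = 0;
        double sqSum = 0;
        for (int i = 0; i < rows; i++) {
            double a = i * 1.5 + j * 0.5; // same fill rule as matrix a in test.java
            double b = i * 1.4 + j * 0.6; // same fill rule as matrix b in test.java
            absSum += Math.abs(a - b);
            sqSum += (a - b) * (a - b);
        }
        return new double[] {absSum / rows, sqSum / rows, Math.sqrt(sqSum) / rows};
    }

    public static void main(String[] args) {
        double[] c = columnCosts(0, 5);
        // Locale.ROOT keeps the decimal point independent of the system locale.
        System.out.println(String.format(Locale.ROOT, "%.4f %.4f %.4f", c[0], c[1], c[2]));
        // prints "0.2000 0.0600 0.1095"
    }
}
```

These are the first entries of the three result rows printed above.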
Next time
I'd like to build the optimizer.
This is the part I understand least, so at first it will probably be gradient descent only.