Edited at

ゼロから作るDeep Learning Java編 第4章 ニューラルネットワークの学習

More than 1 year has passed since last update.


目次


4.2 損失関数

今までに登場した各種関数を簡単に呼べるようにFunctionsクラスにまとめています。


4.2.1 2乗和誤差

public static double mean_squared_error(INDArray y, INDArray t) {

INDArray diff = y.sub(t);
// 自分自身の転置行列と内積をとります。
return 0.5 * diff.mmul(diff.transpose()).getDouble(0);
}

public static double mean_squared_error2(INDArray y, INDArray t) {
// ND4Jの2乗距離関数を使います。
return 0.5 * (double)y.squaredDistance(t);
}

INDArray t = Nd4j.create(new double[] {0, 0, 1, 0, 0, 0, 0, 0, 0, 0});
INDArray y = Nd4j.create(new double[] {0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0});
assertEquals(0.097500000000000031, mean_squared_error(y, t), 5e-6);
assertEquals(0.097500000000000031, mean_squared_error2(y, t), 5e-6);
// LossFunctions.LossFunction.MSEを使っても実現できます。
assertEquals(0.097500000000000031, LossFunctions.score(t, LossFunctions.LossFunction.MSE, y, 0, 0, false), 5e-6);
y = Nd4j.create(new double[] {0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0});
assertEquals(0.59750000000000003, mean_squared_error(y, t), 5e-6);
assertEquals(0.59750000000000003, mean_squared_error2(y, t), 5e-6);
assertEquals(0.59750000000000003, LossFunctions.score(t, LossFunctions.LossFunction.MSE, y, 0, 0, false), 5e-6);


4.2.2 交差エントロピー誤差

public static double cross_entropy_error(INDArray y, INDArray t) {

double delta = 1e-7;
// Python: return -np.sum(t * np.log(y + delta))
return -t.mul(Transforms.log(y.add(delta))).sumNumber().doubleValue();
}

INDArray t = Nd4j.create(new double[] {0, 0, 1, 0, 0, 0, 0, 0, 0, 0});
INDArray y = Nd4j.create(new double[] {0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0});
assertEquals(0.51082545709933802, cross_entropy_error(y, t), 5e-6);
// LossFunctionsを使って実現することもできます。
assertEquals(0.51082545709933802, LossFunctions.score(t, LossFunctions.LossFunction.MCXENT, y, 0, 0, false), 5e-6);
y = Nd4j.create(new double[] {0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0});
assertEquals(2.3025840929945458, cross_entropy_error(y, t), 5e-6);


4.2.3 ミニバッチ学習

サンプルをランダムに抽出するためにND4JのDataSetクラスを使用します。

// MNISTデータセットを読み込みます。

MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
assertArrayEquals(new int[] {60000, 784}, train.normalizedImages().shape());
assertArrayEquals(new int[] {60000, 10}, train.oneHotLabels().shape());
// ランダムに10枚のイメージを抽出します。
// 一度DataSetにイメージとラベルを格納し、サンプルとして指定枚数分を取り出します。
DataSet ds = new DataSet(train.normalizedImages(), train.oneHotLabels());
DataSet sample = ds.sample(10);
assertArrayEquals(new int[] {10, 784}, sample.getFeatureMatrix().shape());
assertArrayEquals(new int[] {10, 10}, sample.getLabels().shape());
// 取得されたサンプルのイメージとラベル値の対応があっていることを確認するために
// サンプルのイメージをPNGファイルとして書き出します。
// one-hot形式のラベルから元のラベル値へ変換します。(各行の最大値のインデックスを求めます)
INDArray indexMax = Nd4j.getExecutioner().exec(new IAMax(sample.getLabels()), 1);
if (!Constants.SampleImagesOutput.exists())
Constants.SampleImagesOutput.mkdirs();
for (int i = 0; i < 10; ++i) {
// ファイル名は"(連番)-(ラベル値).png"となります。
File f = new File(Constants.SampleImagesOutput,
String.format("%05d-%d.png",
i, indexMax.getInt(i)));
MNISTImages.writePngFile(sample.getFeatures().getRow(i), train.rows, train.columns, f);
}


4.2.4 [バッチ対応版]交差エントロピー誤差の実装

public static double cross_entropy_error2(INDArray y, INDArray t) {

int batch_size = y.size(0);
return -t.mul(Transforms.log(y.add(1e-7))).sumNumber().doubleValue() / batch_size;
}

// 単一データの場合
INDArray t = Nd4j.create(new double[] {0, 0, 1, 0, 0, 0, 0, 0, 0, 0});
INDArray y = Nd4j.create(new double[] {0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0});
assertEquals(0.51082545709933802, cross_entropy_error2(y, t), 5e-6);
// バッチサイズ=2の場合(同一データが2件)
t = Nd4j.create(new double[][] {
{0, 0, 1, 0, 0, 0, 0, 0, 0, 0},
{0, 0, 1, 0, 0, 0, 0, 0, 0, 0}});
y = Nd4j.create(new double[][] {
{0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0},
{0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0}});
assertEquals(0.51082545709933802, cross_entropy_error2(y, t), 5e-6);
// todo: one-hot表現でない場合の交差エントロピー誤差の実装


4.3 数値微分


4.3.1 微分

public static double numerical_diff_bad(DoubleUnaryOperator f, double x) {

double h = 10e-50;
return (f.applyAsDouble(x + h) - f.applyAsDouble(x)) / h;
}

assertEquals(0.0, (float)1e-50, 1e-52);


4.3.2 数値微分の例

public static double numerical_diff(DoubleUnaryOperator f, double x) {

double h = 1e-4;
return (f.applyAsDouble(x + h) - f.applyAsDouble(x - h)) / (h * 2);
}

public double function_1(double x) {
return 0.01 * x * x + 0.1 * x;
}

public double function_1_diff(double x) {
return 0.02 * x + 0.1;
}
assertEquals(0.200, numerical_diff(this::function_1, 5), 5e-6);
assertEquals(0.300, numerical_diff(this::function_1, 10), 5e-6);
assertEquals(0.200, function_1_diff(5), 5e-6);
assertEquals(0.300, function_1_diff(10), 5e-6);


4.3.3 偏微分

public double function_2(INDArray x) {

double x0 = x.getDouble(0);
double x1 = x.getDouble(1);
return x0 * x0 + x1 * x1;
}

DoubleUnaryOperator function_tmp1 = x0 -> x0 * x0 + 4.0 * 4.0;
assertEquals(6.00, numerical_diff(function_tmp1, 3.0), 5e-6);
DoubleUnaryOperator function_tmp2 = x1 -> 3.0 * 3.0 + x1 * x1;
assertEquals(8.00, numerical_diff(function_tmp2, 4.0), 5e-6);


4.4 勾配

public double function_2(INDArray x) {

double x0 = x.getFloat(0);
double x1 = x.getFloat(1);
return x0 * x0 + x1 * x1;
// または
// return x.mul(x).sumNumber().doubleValue();
// あるいは以下のように転置行列と内積をとることもできます。
// return x.mmul(x.transpose()).getDouble(0);
}

assertEquals("[6.00,8.00]", Util.string(Functions.numerical_gradient(this::function_2, Nd4j.create(new double[] {3.0, 4.0}))));
assertEquals("[0.00,4.00]", Util.string(Functions.numerical_gradient(this::function_2, Nd4j.create(new double[] {0.0, 2.0}))));
assertEquals("[6.00,0.00]", Util.string(Functions.numerical_gradient(this::function_2, Nd4j.create(new double[] {3.0, 0.0}))));


4.4.1 勾配法

public static INDArray gradient_descent(INDArrayFunction f, INDArray init_x, double lr, int step_num) {

INDArray x = init_x;
for (int i = 0; i < step_num; ++i) {
INDArray grad = Functions.numerical_gradient(f, x);
INDArray y = x.sub(grad.mul(lr));
// System.out.printf("step:%d x=%s grad=%s x'=%s%n", i, x, grad, y);
x = y;
}
return x;
}

// lr = 0.1
INDArray init_x = Nd4j.create(new double[] {-3.0, 4.0});
INDArray r = gradient_descent(this::function_2, init_x, 0.1, 100);
assertEquals("[-0.00,0.00]", Util.string(r));
assertEquals(-6.11110793e-10, r.getDouble(0), 5e-6);
assertEquals(8.14814391e-10, r.getDouble(1), 5e-6);
// 学習率が大きすぎる例: lr = 10.0
r = gradient_descent(this::function_2, init_x, 10.0, 100);
// Pythonの結果とは同じになりませんが、いずれにしても正しい結果は得られません。
assertEquals("[-763,389.44,1,017,852.62]", Util.string(r));
// 学習率が小さすぎる例: lr = 1e-10
r = gradient_descent(this::function_2, init_x, 1e-10, 100);
assertEquals("[-3.00,4.00]", Util.string(r));


4.4.2 ニューラルネットワークに対する勾配

static class simpleNet {

/** 重み */
public final INDArray W;

/**
* 重みを0.0から1.0の範囲の乱数で初期化します。
*/

public simpleNet() {
try (Random r = new DefaultRandom()) {
// 2x3のガウス分布に基づく乱数の行列を作成します。
W = r.nextGaussian(new int[] {2, 3});
} catch (Exception e) {
throw new RuntimeException(e);
}
}

/**
* 本書と結果が一致することを確認するため重みを
* 外部から与えることができるようにします。
*/

public simpleNet(INDArray W) {
this.W = W.dup(); // 防衛的にコピーします。
}

public INDArray predict(INDArray x) {
return x.mmul(W);
}

public double loss(INDArray x, INDArray t) {
INDArray z = predict(x);
INDArray y = Functions.softmax(z);
double loss = Functions.cross_entropy_error(y, t);
return loss;
}
}

// 重みは乱数ではなく本書と同じ値を与えてみます。
INDArray W = Nd4j.create(new double[][] {
{0.47355232, 0.9977393, 0.84668094},
{0.85557411, 0.03563661, 0.69422093},
});
simpleNet net = new simpleNet(W);
assertEquals("[[0.47,1.00,0.85],[0.86,0.04,0.69]]", Util.string(net.W));
INDArray x = Nd4j.create(new double[] {0.6, 0.9});
INDArray p = net.predict(x);
assertEquals("[1.05,0.63,1.13]", Util.string(p));
assertEquals(2, Functions.argmax(p).getInt(0));
INDArray t = Nd4j.create(new double[] {0, 0, 1});
assertEquals(0.92806853663411326, net.loss(x, t), 5e-6);
// 関数定義はラムダ式を使っています。
INDArrayFunction f = dummy -> net.loss(x, t);
INDArray dW = Functions.numerical_gradient(f, net.W);
assertEquals("[[0.22,0.14,-0.36],[0.33,0.22,-0.54]]", Util.string(dW));


4.5 学習アルゴリズムの実装


4.5.1 2層ニューラルネットワークのクラス

2層ニューラルネットワークのクラスはTwoLayerNetです。重みとバイアスはMapではなくてTwoLayerParamsに格納します。乱数はND4JのRandamインタフェースを使用します。

私の環境では5分から10分程度かかります。

TwoLayerNet net = new TwoLayerNet(784, 100, 10);

assertArrayEquals(new int[] {784, 100}, net.parms.get("W1").shape());
assertArrayEquals(new int[] {1, 100}, net.parms.get("b1").shape());
assertArrayEquals(new int[] {100, 10}, net.parms.get("W2").shape());
assertArrayEquals(new int[] {1, 10}, net.parms.get("b2").shape());
try (Random r = new DefaultRandom()) {
INDArray x = r.nextGaussian(new int[] {100, 784});
INDArray t = r.nextGaussian(new int[] {100, 10});
INDArray y = net.predict(x);
assertArrayEquals(new int[] {100, 10}, y.shape());
Params grads = net.numerical_gradient(x, t);
assertArrayEquals(new int[] {784, 100}, grads.get("W1").shape());
assertArrayEquals(new int[] {1, 100}, grads.get("b1").shape());
assertArrayEquals(new int[] {100, 10}, grads.get("W2").shape());
assertArrayEquals(new int[] {1, 10}, grads.get("b2").shape());
}


4.5.2 ミニバッチ学習の実装

MNISTのデータを使うと非常に時間がかかります。私の環境ではループ1回につき90秒程度かかるので、10000回ループすると10日ほどかかる計算になります。

// MNISTデータセットを読み込みます。

MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();
assertArrayEquals(new int[] {60000, 784}, x_train.shape());
assertArrayEquals(new int[] {60000, 10}, t_train.shape());
List<Double> train_loss_list = new ArrayList<>();
int iters_num = 10000;
// int train_size = images.size(0);
int batch_size = 100;
double learning_rate = 0.1;
TwoLayerNet network = new TwoLayerNet(784, 50, 10);
// batch_size分のデータをランダムに取り出します。
for (int i = 0; i < iters_num; ++i) {
long start = System.currentTimeMillis();
// ミニバッチの取得
DataSet ds = new DataSet(x_train, t_train);
DataSet sample = ds.sample(batch_size);
INDArray x_batch = sample.getFeatureMatrix();
INDArray t_batch = sample.getLabels();
Params grad = network.numerical_gradient(x_batch, t_batch);
network.parms.update((p, a) -> p.subi(a.mul(learning_rate)), grad);
// 学習経過の記録
double loss = network.loss(x_batch, t_batch);
train_loss_list.add(loss);
System.out.printf("iteration %d loss=%f elapse=%dms%n",
i, loss, System.currentTimeMillis() - start);
}


4.5.3 テストデータで評価

MNISTのデータを使うと非常に時間がかかります。私の環境ではループ1回につき90秒程度かかるので、10000回ループすると10日ほどかかる計算になります。そのため最後まで実行したことはありませんが、訓練データとテストデータの認識精度が共に90%以上になるまで訓練するのに要した時間は4.6時間でループ回数は214回でした。本にあるグラフで分かるとおり80%くらいまではかなり早く上がるので、10000回ループするのではなく、認識精度がしきい値を超えた時点で打ち切るようにした方がいいかもしれません。

// MNISTデータセットを読み込みます。

MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertArrayEquals(new int[] {60000, 784}, x_train.shape());
assertArrayEquals(new int[] {60000, 10}, t_train.shape());
List<Double> train_loss_list = new ArrayList<>();
List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();
int iters_num = 10000;
int train_size = x_train.size(0);
int batch_size = 100;
double learning_rate = 0.01;
int iter_per_epoch = Math.max(train_size / batch_size, 1);
TwoLayerNet network = new TwoLayerNet(784, 50, 10);
// batch_size分のデータをランダムに取り出します。
for (int i = 0; i < iters_num; ++i) {
long start = System.currentTimeMillis();
// ミニバッチの取得
DataSet ds = new DataSet(x_train, t_train);
DataSet sample = ds.sample(batch_size);
INDArray x_batch = sample.getFeatureMatrix();
INDArray t_batch = sample.getLabels();
Params grad = network.numerical_gradient(x_batch, t_batch);
network.parms.update((p, a) -> p.subi(a.mul(learning_rate)), grad);
// 学習経過の記録
double loss = network.loss(x_batch, t_batch);
train_loss_list.add(loss);
// 1エポックごとに認識制度を計算
if (i % iter_per_epoch == 0) {
double train_acc = network.accuracy(x_train, t_train);
double test_acc = network.accuracy(x_test, t_test);
train_acc_list.add(train_acc);
test_acc_list.add(test_acc);
System.out.printf("train acc, test acc | %s, %s%n",
train_acc, test_acc);
}
System.out.printf("iteration %d loss=%f elapse=%dms%n",
i, loss, System.currentTimeMillis() - start);
}