6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ゼロから作るDeep Learning Java編 第5章 誤差逆伝播法

Last updated at Posted at 2018-03-04

目次

#5.4 単純なレイヤの実装

##5.4.1 乗算レイヤの実装

static class MulLayer {

    private double x, y;

    public double forward(double x, double y) {
        this.x = x;
        this.y = y;
        return x * y;
    }

    // Javaでは多値を返せないので配列で返します。
    public double[] backward(double dout) {
        return new double[] {dout * y, dout * x};
    }
}

double apple = 100;
double apple_num = 2;
double tax = 1.1;
// Layer
MulLayer mul_apple_layer = new MulLayer();
MulLayer mul_tax_layer = new MulLayer();
// forward
double apple_price = mul_apple_layer.forward(apple, apple_num);
double price = mul_tax_layer.forward(apple_price, tax);
assertEquals(220.0, price, 5e-6);
// backward
double dprice = 1;
double[] dapple_price_tax = mul_tax_layer.backward(dprice);
double[] dapple_num = mul_apple_layer.backward(dapple_price_tax[0]);
assertEquals(2.2, dapple_num[0], 5e-6);
assertEquals(110.0, dapple_num[1], 5e-6);
assertEquals(200.0, dapple_price_tax[1], 5e-6);

##5.4.2 加算レイヤの実装

static class AddLayer {

    public double forward(double x, double y) {
        return x + y;
    }

    public double[] backward(double dout) {
        return new double[] {dout, dout};
    }
}

double apple = 100;
double apple_num = 2;
double orange = 150;
double orange_num = 3;
double tax = 1.1;
// Layer
MulLayer mul_apple_layer = new MulLayer();
MulLayer mul_orange_layer = new MulLayer();
AddLayer add_apple_orange_layer = new AddLayer();
MulLayer mul_tax_layer = new MulLayer();
// forward
double apple_price = mul_apple_layer.forward(apple, apple_num);
double orange_price = mul_orange_layer.forward(orange, orange_num);
double all_price = add_apple_orange_layer.forward(apple_price, orange_price);
double price = mul_tax_layer.forward(all_price, tax);
// backward
double dprice = 1;
double[] dall_price = mul_tax_layer.backward(dprice);
double[] dapple_dorange_price = add_apple_orange_layer.backward(dall_price[0]);
double[] dorange = mul_orange_layer.backward(dapple_dorange_price[1]);
double[] dapple = mul_apple_layer.backward(dapple_dorange_price[0]);
assertEquals(715.0, price, 5e-6);
assertEquals(110.0, dapple[1], 5e-6);
assertEquals(2.2, dapple[0], 5e-6);
assertEquals(3.3, dorange[0], 5e-6);
assertEquals(165.0, dorange[1], 5e-6);
assertEquals(650.0, dall_price[1], 5e-6);

#5.5 活性化関数レイヤの実装

##5.5.1 ReLUレイヤ

誤差逆伝播法を使ったReLU例のクラス実装はReluです。

INDArray x = Nd4j.create(new double[][] {{1.0, -0.5}, {-2.0, 3.0}});
assertEquals("[[1.00,-0.50],[-2.00,3.00]]", Util.string(x));
// 本書とは違ったテストをします。
Relu relu = new Relu();
INDArray a = relu.forward(x);
// forwardの結果
assertEquals("[[1.00,0.00],[0.00,3.00]]", Util.string(a));
// mask
assertEquals("[[1.00,0.00],[0.00,1.00]]", Util.string(relu.mask));
INDArray dout = Nd4j.create(new double[][] {{5, 6}, {7, 8}});
INDArray b = relu.backward(dout);
// backwardの結果
assertEquals("[[5.00,0.00],[0.00,8.00]]", Util.string(b));

##5.5.2 Sigmoidレイヤ

誤差逆伝播法を使ったSigmoidレイヤの実装クラスはSigmoidです。

#5.6 Affine Softmaxレイヤの実装

##5.6.1 Affineレイヤ

誤差逆伝播法を使ったAffineレイヤの実装クラスはAffineです。

try (Random r = new DefaultRandom()) {
    INDArray X = r.nextGaussian(new int[] {2});
    INDArray W = r.nextGaussian(new int[] {2, 3});
    INDArray B = r.nextGaussian(new int[] {3});
    assertArrayEquals(new int[] {1, 2}, X.shape());
    assertArrayEquals(new int[] {2, 3}, W.shape());
    assertArrayEquals(new int[] {1, 3}, B.shape());
    INDArray Y = X.mmul(W).addRowVector(B);
    assertArrayEquals(new int[] {1, 3}, Y.shape());
}

##5.6.2 バッチ版Affineレイヤ

INDArray X_dot_W = Nd4j.create(new double[][] {{0, 0, 0}, {10, 10, 10}});
INDArray B = Nd4j.create(new double[] {1, 2, 3});
assertEquals("[[0.00,0.00,0.00],[10.00,10.00,10.00]]", Util.string(X_dot_W));
assertEquals("[[1.00,2.00,3.00],[11.00,12.00,13.00]]", Util.string(X_dot_W.addRowVector(B)));

##5.6.3 Softmax-with-Lossレイヤ

誤差逆伝播法を使ったSoftmax-with-Lossレイヤの実装クラスはSoftmaxWithLossです。

#5.7 誤差逆伝播法の実装

誤差逆伝播法を使った2層ニューラルネットワークのクラスはTwoLayerNetです。
4章にもTwoLayerNetがありますが、こちらは数値微分を使ったものです。この実装にあたって各レイヤを扱いやすくするため、以下のふたつのインタフェースを定義しています。

public interface Layer {

    INDArray forward(INDArray x);
    INDArray backward(INDArray x);

}

public interface LastLayer {

    double forward(INDArray x, INDArray t);
    INDArray backward(INDArray x);

}

##5.7.3 誤差逆伝播法の勾配確認

数値微分による方法との勾配比較は本書にあるよりもかなり大きな差となっています。そのため、ここでは数値微分による勾配を3で割ったものと比較しています。Functions.average(INDArray)はすべての要素の平均を求めるメソッドです。

public static double average(INDArray x) {
    // x.length()はすべての要素数を返します。
    return x.sumNumber().doubleValue() / x.length();
}
// MNISTの訓練データを読み込みます。
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
TwoLayerNet network = new TwoLayerNet(784, 50, 10);
// 正規化されたイメージとone-hotラベルの先頭3個をそれぞれ取り出します。
int batch_size = 3;
INDArray x_batch = train.normalizedImages().get(NDArrayIndex.interval(0, batch_size));
INDArray t_batch = train.oneHotLabels().get(NDArrayIndex.interval(0, batch_size));
// 勾配を数値微分によって求めます。
Params grad_numerical = network.numerical_gradient(x_batch, t_batch);
// 勾配を誤差伝播法によって求めます。
Params grad_backprop = network.gradient(x_batch, t_batch);
// 数値微分と誤差伝播法の結果を比較します。
double diff_W1 = Functions.average(Transforms.abs(grad_backprop.get("W1").sub(grad_numerical.get("W1"))));
double diff_b1 = Functions.average(Transforms.abs(grad_backprop.get("b1").sub(grad_numerical.get("b1"))));
double diff_W2 = Functions.average(Transforms.abs(grad_backprop.get("W2").sub(grad_numerical.get("W2"))));
double diff_b2 = Functions.average(Transforms.abs(grad_backprop.get("b2").sub(grad_numerical.get("b2"))));
System.out.println("W1=" + diff_W1);
System.out.println("b1=" + diff_b1);
System.out.println("W2=" + diff_W2);
System.out.println("b2=" + diff_b2);
// 差分は本書より少し大きめです。
assertTrue(diff_b1 < 1e-3);
assertTrue(diff_W2 < 1e-3);
assertTrue(diff_b2 < 1e-3);
assertTrue(diff_W1 < 1e-3);

#5.7.4 誤差逆伝播法を使った学習

数値微分を使った学習に比べるとかなり速くなります。私の環境では89秒程度で10000回のループが終了します。ただし最終的な認識精度は84%程度であり、数値微分に劣る結果となりました。勾配を数値微分の場合と比較したときの差が大きいので、おそらくレイヤの実装のどこかに誤りがあるものと思われます。

// MNISTの訓練データを読み込みます。
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();
// MNISTのテストデータを読み込みます。
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
TwoLayerNet network = new TwoLayerNet(784, 50, 10);
DataSet dataSet = new DataSet(x_train, t_train);
int iters_num = 10000;
int train_size = x_train.size(0);
int batch_size = 100;
double learning_rate = 0.1;
List<Double> train_loss_list = new ArrayList<>();
List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();
int iter_per_epoch = Math.max(train_size / batch_size, 1);
for (int i = 0; i < iters_num; ++i) {
    DataSet sample = dataSet.sample(batch_size);
    INDArray x_batch = sample.getFeatures();
    INDArray t_batch = sample.getLabels();
    // 誤差逆伝播法によって勾配を求める
    Params grad = network.gradient(x_batch, t_batch);
    // 更新
    network.params.update((p, a) -> p.subi(a.mul(learning_rate)), grad);
    double loss = network.loss(x_batch, t_batch);
    train_loss_list.add(loss);
    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.printf("loss=%f train_acc=%f test_acc=%f%n", loss, train_acc, test_acc);
    }
}
assertTrue(train_acc_list.get(train_acc_list.size() - 1) > 0.8);
assertTrue(test_acc_list.get(test_acc_list.size() - 1) > 0.8);
6
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?