Help us understand the problem. What is going on with this article?

ゼロから作るDeep Learning Java編 第3章 ニューラルネットワーク

More than 1 year has passed since last update.

目次

3.2 活性化関数

3.2.3 ステップ関数の実装

ステップ関数は以下のように実装できます。ここでは単純な関数(DoubleFunction)でINDArrayをマップできるようにしてみました。

public static double step_function(double x) {
    if (x > 0)
        return 1.0;
    else
        return 0.0;
}

public static <T extends Number> INDArray map(INDArray x, DoubleFunction<T> func) {
    int size = x.length();
    INDArray result = Nd4j.create(size);
    for (int i = 0; i < size; ++i)
        result.put(0, i, func.apply(x.getDouble(i)));
    return result;
}

public static INDArray step_function(INDArray x) {
    return map(x, d -> d > 0.0 ? 1 : 0);
}

INDArray x = Nd4j.create(new double[] {-1.0, 1.0, 2.0});
assertEquals("[-1.00,1.00,2.00]", Util.string(x));
assertEquals("[0.00,1.00,1.00]", Util.string(step_function(x)));

3.2.4 シグモイド関数の実装

シグモイド関数は以下のように実装できます。ND4JにはTransformsクラスにsigmoid関数が用意されているので、これを利用することもできます。

public static double sigmoid(double x) {
    return (double)(1.0 / (1.0 + Math.exp(-x)));
}

public static INDArray sigmoid(INDArray x) {
    // Javaでは演算子のオーバーロードができないので
    // メソッド呼び出しで記述します。
    return Transforms.exp(x.neg()).add(1.0).rdiv(1.0);
    // あるいは前述のmapを使って以下のように実装することもできます。
    // return map(x, d -> sigmoid(d));
}


INDArray x = Nd4j.create(new double[] {-1.0, 1.0, 2.0});
assertEquals("[0.27,0.73,0.88]", Util.string(sigmoid(x)));
assertEquals("[0.27,0.73,0.88]", Util.string(Transforms.sigmoid(x)));
INDArray t = Nd4j.create(new double[] {1.0, 2.0, 3.0});
// A.rdiv(k)はkをAの各要素で割ったものになります。
assertEquals("[1.00,0.50,0.33]", Util.string(t.rdiv(1.0)));
}

3.2.7 ReLU関数

ReLU関数は以下のように実装できます。ND4JにはTransformsクラスにrelu関数が用意されているのでこれを使うこともできます。

public static INDArray relu(INDArray x) {
    return map(x, d -> Math.max(0.0, d));
}

INDArray x = Nd4j.create(new double[] {-4, -2, 0, 2, 4});
assertEquals("[0.00,0.00,0.00,2.00,4.00]", Util.string(relu(x)));
assertEquals("[0.00,0.00,0.00,2.00,4.00]", Util.string(Transforms.relu(x)));

3.3 多次元配列の計算

3.3.1 多次元配列

// 1次元の配列
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals("[1.00,2.00,3.00,4.00]", Util.string(A));
// ND4Jでは1次元配列は1×Nの2次元配列となります。
assertArrayEquals(new int[] {1, 4}, A.shape());
// ND4Jでは次元数はrank()メソッドで求めます。
assertEquals(2, A.rank());
assertEquals(1, A.size(0));  // 行数
assertEquals(4, A.size(1));  // 列数
// 2次元の配列
INDArray B = Nd4j.create(new double[][] {{1, 2}, {3, 4}, {5, 6}});
assertEquals("[[1.00,2.00],[3.00,4.00],[5.00,6.00]]", Util.string(B));
assertEquals(2, B.rank());
assertArrayEquals(new int[] {3, 2}, B.shape());

3.3.2 行列の積

ND4Jでは内積をINDArray.mmul(INDArray)で求めます。

INDArray A = Nd4j.create(new double[][] {{1, 2}, {3, 4}});
assertArrayEquals(new int[] {2, 2}, A.shape());
INDArray B = Nd4j.create(new double[][] {{5, 6}, {7, 8}});
assertArrayEquals(new int[] {2, 2}, B.shape());
assertEquals("[[19.00,22.00],[43.00,50.00]]", Util.string(A.mmul(B)));

A = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
assertArrayEquals(new int[] {2, 3}, A.shape());
B = Nd4j.create(new double[][] {{1, 2}, {3, 4}, {5, 6}});
assertArrayEquals(new int[] {3, 2}, B.shape());
assertEquals("[[22.00,28.00],[49.00,64.00]]", Util.string(A.mmul(B)));

INDArray C = Nd4j.create(new double[][] {{1, 2}, {3, 4}});
assertArrayEquals(new int[] {2, 2}, C.shape());
assertArrayEquals(new int[] {2, 3}, A.shape());
try {
    // ND4Jでは内積をとる行列の要素数に誤りがある場合、
    // ND4JIllegalStateExceptionをスローします。
    A.mmul(C);
    fail();
} catch (ND4JIllegalStateException e) {
    assertEquals(
        "Cannot execute matrix multiplication: [2, 3]x[2, 2]: "
        + "Column of left array 3 != rows of right 2"
        , e.getMessage());
}

A = Nd4j.create(new double[][] {{1, 2}, {3, 4}, {5, 6}});
assertArrayEquals(new int[] {3, 2}, A.shape());
B = Nd4j.create(new double[] {7, 8});
assertArrayEquals(new int[] {1, 2}, B.shape());
// ND4Jでは1次元配列は1×N行の行列となるため
// 積を求める場合はtranspose()メソッドで転置する必要があります。
assertArrayEquals(new int[] {2, 1}, B.transpose().shape());
assertEquals("[23.00,53.00,83.00]", Util.string(A.mmul(B.transpose())));

3.3.3 ニューラルネットワークの行列の積

INDArray X = Nd4j.create(new double[] {1, 2});
assertArrayEquals(new int[] {1, 2}, X.shape());
INDArray W = Nd4j.create(new double[][] {{1, 3, 5}, {2, 4, 6}});
assertEquals("[[1.00,3.00,5.00],[2.00,4.00,6.00]]", Util.string(W));
assertArrayEquals(new int[] {2, 3}, W.shape());
INDArray Y = X.mmul(W);
assertEquals("[5.00,11.00,17.00]", Util.string(Y));

3.4 3層ニューラルネットワークの実装

3.4.2 各層における信号伝達の実装

INDArray X = Nd4j.create(new double[] {1.0, 0.5});
INDArray W1 = Nd4j.create(new double[][] {{0.1, 0.3, 0.5}, {0.2, 0.4, 0.6}});
INDArray B1 = Nd4j.create(new double[] {0.1, 0.2, 0.3});
assertArrayEquals(new int[] {2, 3}, W1.shape());
assertArrayEquals(new int[] {1, 2}, X.shape());
assertArrayEquals(new int[] {1, 3}, B1.shape());
INDArray A1 = X.mmul(W1).add(B1);
INDArray Z1 = Transforms.sigmoid(A1);
assertEquals("[0.30,0.70,1.10]", Util.string(A1));
assertEquals("[0.57,0.67,0.75]", Util.string(Z1));

INDArray W2 = Nd4j.create(new double[][] {{0.1, 0.4}, {0.2, 0.5}, {0.3, 0.6}});
INDArray B2 = Nd4j.create(new double[] {0.1, 0.2});
assertArrayEquals(new int[] {1, 3}, Z1.shape());
assertArrayEquals(new int[] {3, 2}, W2.shape());
assertArrayEquals(new int[] {1, 2}, B2.shape());
INDArray A2 = Z1.mmul(W2).add(B2);
INDArray Z2 = Transforms.sigmoid(A2);
assertEquals("[0.52,1.21]", Util.string(A2));
assertEquals("[0.63,0.77]", Util.string(Z2));

INDArray W3 = Nd4j.create(new double[][] {{0.1, 0.3}, {0.2, 0.4}});
INDArray B3 = Nd4j.create(new double[] {0.1, 0.2});
INDArray A3 = Z2.mmul(W3).add(B3);
// ND4JにはTransformsクラスにidentity(INDArray)メソッドが用意されています。
INDArray Y = Transforms.identity(A3);
assertEquals("[0.32,0.70]", Util.string(A3));
assertEquals("[0.32,0.70]", Util.string(Y));
// Y.equals(A3)は真となります。
assertEquals(A3, Y);

3.4.3 実装のまとめ

public static Map<String, INDArray> init_network() {
    Map<String, INDArray> network = new HashMap<>();
    network.put("W1", Nd4j.create(new double[][] {{0.1, 0.3, 0.5}, {0.2, 0.4, 0.6}}));
    network.put("b1", Nd4j.create(new double[] {0.1, 0.2, 0.3}));
    network.put("W2", Nd4j.create(new double[][] {{0.1, 0.4}, {0.2, 0.5}, {0.3, 0.6}}));
    network.put("b2", Nd4j.create(new double[] {0.1, 0.2}));
    network.put("W3", Nd4j.create(new double[][] {{0.1, 0.3}, {0.2, 0.4}}));
    network.put("b3", Nd4j.create(new double[] {0.1, 0.2}));
    return network;
}

public static INDArray forward(Map<String, INDArray> network, INDArray x) {
    INDArray W1 = network.get("W1");
    INDArray W2 = network.get("W2");
    INDArray W3 = network.get("W3");
    INDArray b1 = network.get("b1");
    INDArray b2 = network.get("b2");
    INDArray b3 = network.get("b3");

    INDArray a1 = x.mmul(W1).add(b1);
    INDArray z1 = Transforms.sigmoid(a1);
    INDArray a2 = z1.mmul(W2).add(b2);
    INDArray z2 = Transforms.sigmoid(a2);
    INDArray a3 = z2.mmul(W3).add(b3);
    INDArray y = Transforms.identity(a3);
    return y;
}

Map<String, INDArray> network = init_network();
INDArray x = Nd4j.create(new double[] {1.0, 0.5});
INDArray y = forward(network, x);
assertEquals("[0.32,0.70]", Util.string(y));

3.5 出力層の設計

3.5.1 恒等関数とソフトマックス関数

INDArray a = Nd4j.create(new double[] {0.3, 2.9, 4.0});
// 指数関数
INDArray exp_a = Transforms.exp(a);
assertEquals("[1.35,18.17,54.60]", Util.string(exp_a));
// 指数関数の和
Number sum_exp_a = exp_a.sumNumber();
assertEquals(74.1221542102, sum_exp_a.doubleValue(), 5e-6);
// ソフトマックス関数
INDArray y = exp_a.div(sum_exp_a);
assertEquals("[0.02,0.25,0.74]", Util.string(y));

3.5.2 ソフトマックス関数実装上の注意

public static INDArray softmax_wrong(INDArray a) {
    INDArray exp_a = Transforms.exp(a);
    Number sum_exp_a = exp_a.sumNumber();
    INDArray y = exp_a.div(sum_exp_a);
    return y;
}

public static INDArray softmax_right(INDArray a) {
    Number c = a.maxNumber();
    INDArray exp_a = Transforms.exp(a.sub(c));
    Number sum_exp_a = exp_a.sumNumber();
    INDArray y = exp_a.div(sum_exp_a);
    return y;
}

INDArray a = Nd4j.create(new double[] {1010, 1000, 990});
// 正しく計算されない
assertEquals("[NaN,NaN,NaN]", Util.string(Transforms.exp(a).div(Transforms.exp(a).sumNumber())));
Number c = a.maxNumber();
assertEquals("[0.00,-10.00,-20.00]", Util.string(a.sub(c)));
assertEquals("[1.00,0.00,0.00]", Util.string(Transforms.exp(a.sub(c)).div(Transforms.exp(a.sub(c)).sumNumber())));

// 間違い
assertEquals("[NaN,NaN,NaN]", Util.string(softmax_wrong(a)));
// 正しい
assertEquals("[1.00,0.00,0.00]", Util.string(softmax_right(a)));
// ND4Jには正しいsoftmax(INDArray)が用意されています。
assertEquals("[1.00,0.00,0.00]", Util.string(Transforms.softmax(a)));

3.5.3 ソフトマックス関数の特徴

INDArray a = Nd4j.create(new double[] {0.3, 2.9, 4.0});
INDArray y = Transforms.softmax(a);
assertEquals("[0.02,0.25,0.74]", Util.string(y));
// 総和は1になります。
assertEquals(1.0, y.sumNumber().doubleValue(), 5e-6);

3.6 手書き数字認識

3.6.1 MNISTデータセット

MNISTのデータの読み込みはMNISTImagesクラスを使って行います。このクラスはTHE MNIST DATABASE of handwritten digitsにあるデータをダウンロードして解凍した後のファイルを読み込みます。

// MNISTデータセットはMNISTImagesクラスに読み込みます。
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
assertEquals(60000, train.size);
assertEquals(784, train.imageSize);
assertEquals(10000, test.size);
assertEquals(784, test.imageSize);

// 訓練データの先頭の100イメージをPNGとして出力します。
if (!Constants.TrainImagesOutput.exists())
    Constants.TrainImagesOutput.mkdirs();
for (int i = 0; i < 100; ++i) {
    File image = new File(Constants.TrainImagesOutput,
        String.format("%05d-%d.png", i, train.label(i)));
    train.writePngFile(i, image);
}

実際に出力したイメージは以下の通りです。ファイル名は"[連番5桁]-[ラベル].png"です。

00000-5.png 00000-5.png
00001-0.png 00001-0.png
00002-4.png 00002-4.png
00003-1.png 00003-1.png
00004-9.png 00004-9.png
00005-2.png 00005-2.png
00006-1.png 00006-1.png
00007-3.png 00007-3.png
00008-1.png 00008-1.png
00009-4.png 00009-4.png
.....

3.6.2 ニューラルネットワークの推論処理

サンプルウェイトデータ(sample_weight.pkl)SampleWeightクラスを使って読み込みます。ただしsample_weight.pklはPythonでシリアライズしたデータなので直接Javaで読み込むことはできません。そこでsample_weight.pklを以下のPythonプログラムを使って一度テキスト化しています。テキスト化後のデータはSampleWeight.txtです。SampleWeightクラスはこのテキストファイルを読み込みます。

sample_weight.py
import pickle
import numpy

pkl = "sample_weight.pkl"
with open(pkl, "rb") as f:
    network = pickle.load(f)
for k, v in network.items():
    print(k, end="")
    dim = v.ndim
    for d in v.shape:
        print("", d, end="")
    print()
    for e in v.flatten():
        print(e)

実際に読み込んで推論を行うコードは以下の通りです。認識精度はゼロから作るDeep Learningと同じく93.52%となります。

static INDArray normalize(byte[][] images) {
    int imageCount = images.length;
    int imageSize = images[0].length;
    INDArray norm = Nd4j.create(imageCount, imageSize);
    for (int i = 0; i < imageCount; ++i)
        for (int j = 0; j < imageSize; ++j)
            norm.putScalar(i, j, (images[i][j] & 0xff) / 255.0);
    return norm;
}

static INDArray predict(Map<String, INDArray> network, INDArray x) {
    INDArray W1 = network.get("W1");
    INDArray W2 = network.get("W2");
    INDArray W3 = network.get("W3");
    INDArray b1 = network.get("b1");
    INDArray b2 = network.get("b2");
    INDArray b3 = network.get("b3");

    // 以下のようにするとバッチ処理でエラーとなります。
    // INDArray a1 = x.mmul(W1).add(b1);
    // x.mmul(W1)の結果が2次元配列なのにb1が1次元であるためです。
    // add(INDArray)は自動的にブロードキャストしません。
    // 以下のように明示的にブロードキャストすることもできます。
    // INDArray a1 = x.mmul(W1).add(b1.broadcast(x.size(0), b1.size(1)));
    INDArray a1 = x.mmul(W1).addRowVector(b1);
    INDArray z1 = Transforms.sigmoid(a1);
    INDArray a2 = z1.mmul(W2).addRowVector(b2);
    INDArray z2 = Transforms.sigmoid(a2);
    INDArray a3 = z2.mmul(W3).addRowVector(b3);
    INDArray y = Transforms.softmax(a3);

    return y;
}

// テスト用のイメージを読み込ます。
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
// サンプルウェイトデータを読み込みます。
Map<String, INDArray> network = SampleWeight.read(Constants.SampleWeights);
// イメージを正規化します(0-255 -> 0.0-1.0)
INDArray x = test.normalizedImages();
int size = x.size(0);
int accuracy_cnt = 0;
for (int i = 0; i < size; ++i) {
    INDArray y = predict(network, x.getRow(i));
    // 最後の引数1は次元を表します。
    INDArray max = Nd4j.getExecutioner().exec(new IAMax(y), 1);
    if (max.getInt(0) == test.label(i))
        ++accuracy_cnt;
}
//        System.out.printf("Accuracy:%f%n", (double) accuracy_cnt / size);
assertEquals(10000, size);
assertEquals(9352, accuracy_cnt);

3.6.3 バッチ処理

INDArrayからバッチサイズ分のデータを取り出すためにNDArrayIndexを使用します。NDArrayIndex.interval(int start, int end)でstart番目からend番目までの要素を取り出すことができます(start ≦ i < end)。またIAMaxクラスを使ってNumPyのargmax関数と同じことができます。

int batch_size = 100;
// テスト用のイメージを読み込ます。
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
// サンプルウェイトデータを読み込みます。
Map<String, INDArray> network = SampleWeight.read(Constants.SampleWeights);
// イメージを正規化します(0-255 -> 0.0-1.0)
INDArray x = test.normalizedImages();
int size = x.size(0);
int accuracy_cnt = 0;
for (int i = 0; i < size; i += batch_size) {
    // バッチサイズ分のイメージを取り出してpredict()を呼びます。
    INDArray y = predict(network, x.get(NDArrayIndex.interval(i, i + batch_size)));
    // 最後の引数1は次元を表します。
    INDArray max = Nd4j.getExecutioner().exec(new IAMax(y), 1);
    for (int j = 0; j < batch_size; ++j)
        if (max.getInt(j) == test.label(i + j))
            ++accuracy_cnt;
}
//        System.out.printf("Accuracy:%f%n", (double) accuracy_cnt / size);
assertEquals(10000, size);
assertEquals(9352, accuracy_cnt);

私の環境ではバッチ処理化によって2.7秒から0.5秒に高速化できました。GPUは使っていません。

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away