LoginSignup
2
4

More than 5 years have passed since last update.

Javaで【ゼロから作るDeep Learning】3.ニューラルネットワーク

Posted at

はじめに

Javaで【ゼロから作るDeep Learning】2.NumPyなんてものは、ない。
の続きです。ようやくニューラルネットワークに入り、ディープラーニングっぽくなってきました。

活性化関数

まず、二次元配列の全ての要素に指定の関数を適応させる共通メソッドです。

ArrayUtil.java
public double[][] apply(double[][] x, DoubleUnaryOperator op){

    validate(x);
    double[][] result = new double[x.length][x[0].length];
    for (int i = 0; i < result.length; i++){
        for (int j = 0; j < result[0].length; j++){
            result[i][j] = op.applyAsDouble(x[i][j]);
        }
    }

    return result;
}

書籍のP48「3.2.4 シグモイド関数の実装」です。Javaにあって良かった、Math.exp。

ArrayUtil.java
private DoubleUnaryOperator sigmoid = p ->1/(1+ Math.exp(-1 * p));
public double[][] sigmoid(double[][] x){
    return apply(x, sigmoid);
}

つづきまして、P52「3.2.7 ReLU関数」。

ArrayUtil.java
private DoubleUnaryOperator relu = p -> p < 0 ? 0 : p;
public double[][] relu(double[][] x){
    return apply(x, relu);
}

ニューラルネットワーク

P65「3.4.3 実装のまとめ」

ArrayUtilTest.java
public void newralnetwork(){
    // init_network()
    double[][] W1 = {{0.1, 0.3, 0.5},{0.2, 0.4, 0.6}};
    double[]   b1 = {0.1, 0.2, 0.3};
    double[][] W2 = {{0.1,0.4},{0.2,0.5},{0.3,0.6}};
    double[]   b2 = {0.1, 0.2};
    double[][] W3 = {{0.1,0.3},{0.2,0.4}};
    double[]   b3 = {0.1, 0.2};

    // x - np.array([1.0, 0.5])
    double[][] x = {{1.0, 0.5}};

    // a1 = np.dot(x,W1)+b1
    double[][] a1 = target.plus( target.multi(x, W1), b1);

    // z1 = sigmoid(a1)
    double[][] z1 = target.sigmoid(a1);

    assertThat(z1[0][0], is(closeTo(0.57444252, 0.00001)));
    assertThat(z1[0][1], is(closeTo(0.66818777, 0.00001)));
    assertThat(z1[0][2], is(closeTo(0.75026011, 0.00001)));

    // a2 = np.dot(z1,W2)+b2
    double[][] a2 = target.plus( target.multi(z1, W2), b2);

    // z2 = sigmoid(a2)
    double[][] z2 = target.sigmoid(a2);
    assertThat(z2[0][0], is(closeTo(0.62624937, 0.00001)));
    assertThat(z2[0][1], is(closeTo(0.7710107, 0.00001)));

    // a3 = np.dot(z2,W3)+b3
    double[][] a3 = target.plus( target.multi(z2, W3), b3);

    // print(y) #[0.31682708,0.69627909]
    assertThat(a3[0][0], is(closeTo(0.31682708, 0.00001)));
    assertThat(a3[0][1], is(closeTo(0.69627909, 0.00001)));
}

出力層の設計

書籍 P69「3.5.2 ソフトマックス関数の実装上の注意」より、ソフトマックス関数を実装。

ArrayUtil.java
public double[] softmax(double[] x){
    double maxValue = Arrays.stream(x).max().getAsDouble();
    double[] value = Arrays.stream(x).map(y-> Math.exp(y - maxValue)).toArray();
    double total = Arrays.stream(value).sum();
    return Arrays.stream(value).map(p -> p/total).toArray();
}

public double[][] softmax(double[][] x){
    double[][] result = new double[x.length][];
    for (int i = 0; i < result.length; i++){
        result[i] = softmax(x[i]);
    }
    return result;
}
ArrayUtilTest.java
ArrayUtil target = new ArrayUtil();

@Test
public void softmax(){
    double[] x = {1010, 1000, 990};
    double[] expected = {9.99954600e-01, 4.53978686e-05, 2.06106005e-09};

    double[] result = target.softmax(x);
    assertThat(result[0], is(closeTo(expected[0], 0.00001)));
    assertThat(result[1], is(closeTo(expected[1], 0.00001)));
    assertThat(result[2], is(closeTo(expected[2], 0.00001)));

    assertThat(Arrays.stream(result).sum(), is(closeTo(1, 0.00001)));
}

おわりに

ひとまずニューラルネットワークとソフトマックス関数が実装できました。
ここらへんから、それ単体では理解できるが、全体としてどうなっているのか理解ができなくなっていきます。書籍通りの出力が出ているので、間違えてはいないはず。

2
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4