LoginSignup
9
19

More than 5 years have passed since last update.

Javaで【ゼロから作るDeep Learning】1.とりあえず、微分と偏微分

Last updated at Posted at 2018-02-24

はじめに

「ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装」という書籍がある。2回読んだが、まあ、分かったようなぁ~、分からないようなぁ~
そもそもPythonで実装しているため、Java開発者としては、なんか誤魔化されているような気がする。動的型付けのため、呼び出す側が何を渡すかで、同じメソッドの引数が、時に数字であり、時に配列であり、、、トリッキーすぎる、、
素直にDeeplearning4jの学習すればよいものを、「そうだ、Javaで実装してみよう」と相成りました。実装するだけのため、解説は書籍をご参照下さい。

とりあえず、微分

そもそもJavaで微分や勾配が実装できるのか(P97 4.3 数値微分/P03 4.4 勾配)。できなければ、どうにもなりそうにないので、とりあえず実験してみた(Java8以降)。

ArrayUtil.java
private static double h = 1e-4; // 非常に小さい数
public double numericalDiff(DoubleUnaryOperator func, double x){
    return (func.applyAsDouble(x + h) - func.applyAsDouble(x-h))/ (2*h);
}

テストの内容は、P103。書籍通りのため、いけてる、と見なす。

ArrayUtilTest.java
@Test
public void numericalDiff1(){
    assertThat(target.numericalDiff(p-> p*p+4*4, 3.0), is(6.00000000000378));
    assertThat(target.numericalDiff(p-> 3*3+p*p, 4.0), is(7.999999999999119));
}

つづいて偏微分

書籍のP104を実装。書籍(Python)の実装では、元の値をtmp_val に代入して、計算後、元の値に戻している。しかしそれをJavaでやると、参照先が同じため、結局元データが変更される。そのため、深いコピーを使用して、元データを保持している。 → 代入直後に計算すれば、問題ないとコメントをいただきました。ごもっともです。

ArrayUtil.java
private static double h = 1e-4; // 非常に小さい数
public double[][] numericalGradient(ToDoubleFunction<double[][]> func, double[][] x){

    int cntRow = x.length;
    int cntCol = x[0].length;

    double[][] result = new double[cntRow][cntCol];
    for (int i=0; i < cntRow; i++){
        for (int j=0; j < cntCol; j++){

            double[][] xPlus = deepCopy(x);
            xPlus[i][j] = xPlus[i][j] + h;

            double[][] xMinus = deepCopy(x);
            xMinus[i][j] = xMinus[i][j] - h;

            result[i][j] = (func.applyAsDouble(xPlus) - func.applyAsDouble(xMinus))/ (2*h);
        }
    }

    return result;
}

public double[][] deepCopy(double[][] x){
    double[][] copy = new double[x.length][];
    for (int i = 0; i < copy.length; i++){
        copy[i] = new double[x[i].length];
        System.arraycopy(x[i], 0, copy[i], 0, x[i].length);
    }
    return copy;
}

テストの内容は、P104。同じく、書籍通りのため、いけてる、と見なす。

ArrayUtilTest.java
@Test
public void numericalGradient(){

    ToDoubleFunction<double[][]> function = p-> p[0][0] * p[0][0] + p[0][1]*p[0][1];
    double[][] x = {{3,4}};
    double[][] result = target.numericalGradient(function, x);

    assertThat(result[0][0], is(6.00000000000378));
    assertThat(result[0][1], is(7.999999999999119));

    result = target.numericalGradient(function, new double[][]{{0,2}});

    assertThat(result[0][0], is(closeTo(0.0, 0.000001)));
    assertThat(result[0][1], is(closeTo(4.0, 0.000001)));
}

おわりに

微分、偏微分は大丈夫そうである。
ちなみに、一応全部?実装しました。問題はPCが遅くて、最終的にちゃんとした結果を出力しているのか、検証できないorz

9
19
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
19