LoginSignup
5
1

More than 5 years have passed since last update.

Javaで【ゼロから作るDeep Learning】2.NumPyなんてものは、ない。

Posted at

はじめに

Javaで【ゼロから作るDeep Learning】1.とりあえず、微分と偏微分
からの続きです。
当然、JavaではNumPyは使用できません、多分。類似のJavaライブラリには、GPUも使えるND4Jがありますが、使用するとそちらが本題になりそうなので、使わず。そしてJavaには、配列がある。
ということで、配列の計算を足し算とかけ算を実装します。引き算はマイナスで足し、割り算は1を割って掛ければ良し(手抜き)。Deep Learning、ほぼ関係なし。

二次元配列の検証

演算をする前に、引数の二次元配列を検証する。観点としては、要素数が0でないこと、2次元目(?)の配列の長さが全て同じこと。

ArrayUtil.java
public void validate(double[][] x){

    if (x.length ==0 || x[0].length ==0){
        throw new IllegalArgumentException();
    }

    if (Arrays.stream(x).skip(1).anyMatch(p -> x[0].length != p.length)){
        throw new IllegalArgumentException();
    }
}

二次元配列の足し算

数値を足す場合、二次元配列の全ての要素に足す

ArrayUtil.java
public double[][] plus(double[][] x, double y){
    validate(x);

    final int resultRow = x.length;
    final int resultCol = x[0].length;

    double[][] result = new double[resultRow][resultCol];
    for (int i = 0; i < resultRow; i++){
        for (int j = 0; j < resultCol; j++){
            result[i][j] = x[i][j] + y;
        }
    }

    return result;
}

一次元配列を足す場合、二次元配列の1次元目が同じ要素に足す

ArrayUtil.java
public double[][] plus(double[][] x, double[] y){
    validate(x);

    if (x[0].length != y.length){
        throw new IllegalArgumentException();
    }

    final int resultRow = x.length;
    final int resultCol = x[0].length;

    double[][] result = new double[resultRow][resultCol];
    for (int i = 0; i < resultRow; i++){
        for (int j = 0; j < resultCol; j++){
            result[i][j] = x[i][j] + y[j];
        }
    }

    return result;
}

一次元配列を足す場合、二次元配列の同じ要素に足す

ArrayUtil.java
public double[][] plus(double[][] x, double[][] y){
    validate(x);
    validate(y);

    if (x.length != y.length || x[0].length != y[0].length){
        throw new IllegalArgumentException();
    }

    final int resultRow = x.length;
    final int resultCol = x[0].length;

    double[][] result = new double[resultRow][resultCol];
    for (int i = 0; i < resultRow; i++){
        for (int j = 0; j < resultCol; j++){
            result[i][j] = x[i][j] + y[i][j];
        }
    }

    return result;
}

二次元配列の掛け算

数値を掛ける場合、二次元配列の全ての要素に掛ける

ArrayUtil.java
public double[][] multi(double[][] x, double y){
    validate(x);

    double[][] result = new double[x.length][x[0].length];
    for (int i = 0; i < result.length; i++){
        for (int j = 0; j < result[i].length; j++){
            result[i][j] = x[i][j] * y;
        }
    }

    return result;
}

二次元配列を掛ける場合、掛けたり足したりする。
書籍のP54「3.3.2 行列の内積」、もしくは、「 ■行列の積 ABの定義」を参照。

ArrayUtil.java
public double[][] multi(double[][] x, double[][] y){

    validate(x);
    validate(y);

    int cntCalc = x[0].length;
    if (cntCalc != y.length){
        throw new IllegalArgumentException();
    }

    final int resultRow = x.length;
    final int resultCol = y[0].length;

    double[][] result = new double[resultRow][resultCol];
    for (int i = 0; i < resultRow; i++){
        for (int j = 0; j < resultCol; j++){
            final int row = i;
            final int col = j;
            result[row][col] = IntStream.range(0, cntCalc).mapToDouble(k -> x[row][k] * y[k][col]).sum();
        }
    }

    return result;
}

おわりに

いまだディープラーニングには届かじ。

5
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
1