0
Help us understand the problem. What are the problem?

posted at

updated at

コンピュータとオセロ対戦50 ~行列計算~

前回

今回の目標

新しい深層学習ライブラリの作成のため、基本となるMatrixクラスを作成する。

ここから本編

これまでの方法では、強いオセロを作ることが難しいことが前回の実験で分かりました。
そのためこれからは深層強化学習、DQNと呼ばれるものをやってみようと思います。
しかし私は強化学習の経験がなく、また、今まで使用していた自作ライブラリは実行速度の点で不安があります。
そこで、強化学習について勉強する傍ら、ライブラリの作り直しを行うことにしました。
この記事では、ライブラリ内で使用する基本データ型となるMatrixクラスを作成します。

少し話は変わりますが、これらの記事で作成したライブラリ「MyNet」には、46 ~モデルの保存とロード~に示したように、以下の欠点がありました。

  • ノードクラスは必要なかったのではないか
  • 倍精度実数を使う必要はなかったのではないか
  • 二次元行列を一次元配列で表現すれば探索や計算がもっと早くできたのではないか
  • 行列同士の計算はstrassenアルゴリズムを使った方が速かったのではないか
  • 全体的に効率が悪そう
  • Matrixクラスのメソッドが、返り値があったりなかったりで分かりづらい

そこで、以下のように改善した新しいライブラリを作成することにしました。

  • ノードクラスをなくす
  • Matrxiクラスの計算メソッドを全て非破壊的に変更し、返り値を持たせる

他の部分についてですが、これらは実行速度を鑑みて改善の余地ありと判断したものです。なので本題に入る前に、これで本当に速くなるのかを調べてみたいと思います。

実行速度実験

floatとdouble

doubleはfloatに比べ、非常に精度が高いです。さらに、オーバーフローするまでの上限値も高いです(深層学習の重みでそんな値になれば、学習失敗してますが)。
そのためdoubleの方が優秀とみられているのか、Javaの実数はデフォルトでdoubleです。

public class Test {
    public static void main(String[] str){
        var a = 0.0;
        System.out.pirntln(a.getClass().getName());
    }
}

上のプログラムはコンパイルエラーが出ますが、そのメッセージ中に

Test.java:4: エラー: シンボルを見つけられません
        System.out.pirntln(a.class);
                           ^
  シンボル:   クラス a
  場所: クラス Test
エラー1個

Test.java:4: エラー: doubleは間接参照できません
        System.out.pirntln(a.getClass().getName());
                            ^
エラー1個

と出ますので、何も書いていなくても0.0はdoubleと判断されることが分かります。

ただしそんなdoubleにも弱点があり、ビット数を多くとることです。そのため、floatよりも処理速度が遅くなることは必至と思われます。実際chainerでは、32ビット実数を使用しています。
ただ、具体的にどのくらい遅くなるのかは実験してみないと分かりません。もしその差が許容できるほど小さければ、より精度の高いdoubleを使った方がよいと思われます。
こういった理由で、floatとdoubleの速度対決を行うことにしました。
使用プログラムはこちら。

public class float_vs_double {
    public static void main(String[] str){
        int[] maxNums = {1000, 1000000};
        long start, end;

        for (int maxNum: maxNums){
            System.out.printf("max num: %d\n", maxNum);

            // float start
            start = System.nanoTime();
            float numFloat = 0.0f;
            for (int i = 0; i < maxNum; i++){
                numFloat += 0.1f;
                numFloat = numFloat * 0.1f;
                numFloat = numFloat / 0.1f;
                numFloat -= 0.1f;
            }
            end = System.nanoTime();
            // float end

            System.out.printf("float time: %d [ns]\n", end - start);
            long floatTime = end - start;

            // double start
            start = System.nanoTime();
            double numDouble = 0.0;
            for (int i = 0; i < maxNum; i++){
                numDouble += 0.1;
                numDouble = numDouble * 0.1;
                numDouble = numDouble / 0.1;
                numDouble -= 0.1;
            }
            end = System.nanoTime();
            // double end

            System.out.printf("double time: %d [ns]\n", end - start);
            System.out.printf("The difference: %d [ns]\n\n", end - start - floatTime);
        }
    }
}

結果はこちら。

max num: 1000
float time: 21300 [ns]
double time: 20500 [ns]
The difference: -800 [ns]

max num: 1000000
float time: 7101500 [ns]
double time: 6795200 [ns]
The difference: -306300 [ns]

doubleの方がわずかに速い結果となりました。いちいちfloatに変換することに時間を割いたからでしょうか。
doubleを使うことにします。

一次元配列と二次元配列

一般的に、二次元配列の方が遅いとされますが、そのかわり直感的に扱えるというメリットがあります。
具体的にどの程度違いがあるのか調べてみたいと思います。

import org.MyNet.matrix.*;

public class one_vs_two_list {
    public static void main(String[] str){
        int[] sizes = {1000};//, 1000000};
        long start, end;
        
        for (int size: sizes){
            System.out.printf("size: %d\n", size);
            double num;

            // one start
            start = System.nanoTime();
            double listOne[] = new double[size * size];
            for (int i = 0; i < size; i++){
                for (int j = 0; j < size; j++){
                    num = listOne[i * size + j];
                }
            }
            end = System.nanoTime();
            // one end

            System.out.printf("one time: %d\n", end - start);

            // two start
            start = System.nanoTime();
            double listTwo[][] = new double[size][size];
            for (int i = 0; i < size; i++){
                for (int j = 0; j < size; j++){
                    num = listTwo[i][j];
                }
            }
            end = System.nanoTime();
            // two end

            System.out.printf("two time: %d\n", end - start);
        }
    }
}
size: 1000
one time: 5548300
two time: 4588800

二次元配列の方が速いという結果になりました。
乗算を使っていないからだと思われますが、ここまで時間を食うものとは知りませんでした。

二重ループと一重ループ

気になったので、一重ループと二重ループについて調べてみます。
一般的に二重ループの方が遅いといわれますが、二重ループの方が直感的に理解できるというメリットがあります。
具体的にどの程度差があるのか調べてみたいと思います。
一重ループは、除算を使用する方法とカウントしていく方法の二種類を調べました。
行列サイズ1000000も調べたかったですが、ヒープメモリが足りませんでした。

import org.MyNet.matrix.*;

public class one_vs_two {
    public static void main(String[] str){
        int[] sizes = {1000};//, 1000000};
        long start, end;
        
        for (int size: sizes){
            System.out.printf("size: %d\n", size);
            Matrix x = new Matrix(new double[size][size]);
            x.fillNum(0.1);
            double num;

            // one start
            start = System.nanoTime();
            for (int i = 0; i < size * size; i++){
                num = x.matrix[i / size][i % size];
            }
            end = System.nanoTime();
            // one end

            System.out.printf("one time div: %d\n", end - start);

            // one start
            start = System.nanoTime();
            int row = 0, col = 0;
            for (int i = 0; i < size * size; i++){
                num = x.matrix[row][col];
                col++;
                if (col >= size){
                    col = 0;
                    row++;
                }
            }
            end = System.nanoTime();
            // one end

            System.out.printf("one time inc: %d\n", end - start);

            // two start
            start = System.nanoTime();
            for (int i = 0; i < size; i++){
                for (int j = 0; j < size; j++){
                    num = x.matrix[i][j];
                }
            }
            end = System.nanoTime();
            // two end

            System.out.printf("two time: %d\n", end - start);
        }
    }
}

実行結果はこちら。

size: 1000
one time div: 6895800
one time inc: 5785500
two time: 3429400

二重ループの方が、半分ほどの時間で終わっていることが分かります。
除算やインクリメントに時間がかかったのだと思いますが、除算はまだしも足し算にここまで時間を食うとは思いませんでした。
二重ループで問題なさそうです。

strassenアルゴリズム

strassenアルゴリズムとは、行列の内積計算を素早く求められるアルゴリズムです。
数学で習った一般的な方法での内積計算をプログラムで実装すると三重ループになりますが、strassenでは7回の再帰計算で内積が行えます。
ただ上記の実験で、forループを節約したことによる計算量の増加により、forループを潤沢に使う方が結果的に素早く計算できることが分かりました。
さらに、strassenのアルゴリズムは

  • 正方行列限定
  • 行数・列数が2の累乗限定

などの制約があるようです。
そのため、採用は見送ることにします。

まとめ

  • 実数型はdouble
  • 二次元配列を使う
  • 二重ループを使う
  • strassenのアルゴリズムは使わない

for文や二次元配列を使わないためにいろいろ工夫するより、直感的に操作量を少なくした方が速くなると分かりました。
これはJavaの結果なので、他のプログラミング言語では違う結果になるかもしれません。

MyNet2

上記の実験MyNetの反省を踏まえ、実行速度を改善した新しい深層学習ライブラリを作成します。
具体的には以下のとおり。

  • ノードクラスは必要なかったのではないか
  • 倍精度実数を使う必要はなかったのではないか
  • 二次元行列を一次元配列で表現すれば探索や計算がもっと早くできたのではないか
  • 行列同士の計算はstrassenアルゴリズムを使った方が速かったのではないか
  • 全体的に効率が悪そう
  • Matrixクラスのメソッドが、返り値があったりなかったりで分かりづらい

  • ノードクラスをなくす
  • 倍精度実数を使う
  • 二次元配列を使う
  • strassenアルゴリズムは使わない
  • 無駄を省くよう意識する
  • 返り値ありで統一する

また、ライブラリを使用する際にimport文をいくつも書くのが面倒に感じたので、MatrixクラスはMyNet2では以下に置くことにします。

MyNet2
└── Matrix.java

コンストラクタ

処理を軽くするため、行列の初期化を行わないコンストラクタを追加しました。
インスタンス化後に結局他の値を入れることが多いので。

Matrix.java
    /**
     * Constructor for this class.
     * @param row Number of row.
     * @param col Number of col.
     */
    public Matrix(int row, int col){
        this.row = row;
        this.col = col;
        this.matrix = new double[this.row][this.col];
    }

    /**
     * Constructor for this class.
     * @param row Number of row.
     * @param col Number of col.
     * @param num Number to fill.
     */
    public Matrix(int row, int col, double num){
        this.row = row;
        this.col = col;

        this.matrix = new double[this.row][this.col];
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                this.matrix[i][j] = num;
            }
        }
    }

    /**
     * Constructor for this class.
     * @param row Number of row.
     * @param col Number of col.
     * @param rand Random instance.
     * @param min Number of min for range.
     * @param max Number of max for range.
     */
    public Matrix(int row, int col, Random rand){
        this.row = row;
        this.col = col;

        this.matrix = new double[this.row][this.col];
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                this.matrix[i][j] = rand.nextDouble();
            }
        }
    }

    /**
     * Constructor for this class.
     * @param row Number of row.
     * @param col Number of col.
     * @param rand Random instance.
     * @param min Number of min for range.
     * @param max Number of max for range.
     */
    public Matrix(int row, int col, Random rand, double min, double max){
        this.row = row;
        this.col = col;

        this.matrix = new double[this.row][this.col];
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                this.matrix[i][j] = rand.nextDouble() * (max-min) + min;
            }
        }
    }

    /**
     * Constructor for this class.
     * @param in Two dimentional matrix of type double[][].
     */
    public Matrix(double[][] in){
        this.row = in.length;
        this.col = in[0].length;

        this.matrix = new double[this.row][this.col];
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                this.matrix[i][j] = in[i][j];
            }
        }
    }

    /**
     * Constructor for this class.
     * @param in Two dimentional matrix of type ArrayList<ArrayList<Double>>.
     */
    public Matrix(ArrayList<ArrayList<Double>> in){
        this.row = in.size();
        this.col = in.get(0).size();

        this.matrix = new double[this.row][this.col];
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                this.matrix[i][j] = in.get(i).get(j);
            }
        }
    }

その他

その他のメソッドは、ほぼすべてにおいて必ず返り値を持たせた程度でほとんど変わりありません。
なので以下にまとめておきます。

詳細
    /**
     * Exit this program.
     * @param msg Message for printing.
     */
    protected void exit(String msg){
        System.out.println(msg);
        System.exit(-1);
    }

    /**
     * Add a matrix to this matrix.
     * @param matrix Append matrix.
     * @return New matrix instance.
     */
    public Matrix add(Matrix matrix){
        if (matrix.row != this.row || matrix.col != this.col){
            this.exit("Adding error");
        }

        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = this.matrix[i][j] + matrix.matrix[i][j];
            }
        }

        return rtn;
    }

    /**
     * Add a number to this matrix.
     * @param num Append number.
     * @return New matrix instance.
     */
    public Matrix add(double num){
        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = this.matrix[i][j] + num;
            }
        }

        return rtn;
    }

    /**
     * Subtract a matrix from this matrix.
     * @param matrix Matrix to subtract.
     * @return New matrix instance.
     */
    public Matrix sub(Matrix matrix){
        if (matrix.row != this.row || matrix.col != this.col){
            this.exit("Subtracting error");
        }

        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = this.matrix[i][j] - matrix.matrix[i][j];
            }
        }

        return rtn;
    }

    /**
     * Subtract a number from this matrix.
     * @param num Number to subtract.
     * @return New matrix instance.
     */
    public Matrix sub(double num){
        return this.add(-num);
    }

    /**
     * Multiply this matrix by a number.
     * @param num Multiplier.
     * @return New matrix instance.
     */
    public Matrix mult(double num){
        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = this.matrix[i][j] * num;
            }
        }

        return rtn;
    }

    /**
     * Divid this matrix by a number.
     * @param num Divider.
     * @return New matrix instance.
     */
    public Matrix div(double num){
        return this.mult(1 / num);
    }

    /**
     * Dot product for two matrices.
     * @param matrix Matrix to dot product.
     * @return New dot producted Matrix instance.
     */
    public Matrix dot(Matrix matrix){
        if (this.col != matrix.row){
            this.exit("dot producting error");
        }

        Matrix rtn = new Matrix(this.row, matrix.col);
        double num = 0;
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < matrix.col; j++){
                num = 0;
                for (int k = 0; k < this.col; k++){
                    num += this.matrix[i][k] * matrix.matrix[k][j];
                }
                rtn.matrix[i][j] = num;
            }
        }

        return rtn;
    }

    /**
     * Create transpose of this matrix.
     * @return New matrix instance transposed of this matrix.
     */
    public Matrix T(){
        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[j][i] = this.matrix[i][j];
            }
        }

        return rtn;
    }

    /**
     * Fill this matrix with a number.
     * @param num Number to fill.
     * @return New matrix instance.
     */
    public Matrix fillNum(double num){
        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = num;
            }
        }

        return rtn;
    }

    /**
     * Fill this matrix with random numbers has range 0~1.
     * @return New matrix instance.
     */
    public Matrix fillNextRandom(){
        return this.fillNextRandom(0);
    }

    /**
     * Fill this matrix with random numbers has range 0~1.
     * @param seed Number of seed.
     * @return New matrix instance.
     */
    public Matrix fillNextRandom(long seed){
        Random rand = new Random(seed);
        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = rand.nextDouble();
            }
        }

        return rtn;
    }

    /**
     * Fill this matrix with random number has range min~max.
     * @param min Number of min for range.
     * @param max Number of max for range.
     */
    public Matrix fillRandom(double min, double max){
        return this.fillRandom(min, max, 0);
    }

    /**
     * Fill this matrix with random number has range min~max.
     * @param min Number of min for range.
     * @param max Number of max for range.
     * @param seed Number of seed.
     */
    public Matrix fillRandom(double min, double max, long seed){
        Random rand = new Random(seed);
        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = rand.nextDouble()*(max-min) + min;
            }
        }

        return rtn;
    }

    /**
     * Append a number to the side of this matrix.
     * @param num Number to append.
     * @return New matrix instance.
     */
    public Matrix appendCol(double num){
        Matrix rtn = new Matrix(this.row, this.col+1);
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = this.matrix[i][j];
            }
        }
        for (int i = 0; i < this.row; i++){
            rtn.matrix[i][this.col] = num;
        }

        return rtn;
    }

    /**
     * Stack matrices horizontally.
     * @param matrices Matrices to stack.
     *                 These should not have more than two columns.
     * @return New Matrix instance stacked.
     */
    public static Matrix hstack(Matrix ... matrices){
        Matrix rtn = new Matrix(matrices[0].row, matrices.length);
        for (int i = 0; i < rtn.row; i++){
            for (int j = 0; j < rtn.col; j++){
                rtn.matrix[i][j] = matrices[j].matrix[i][0];
            }
        }

        return rtn;
    }

    /**
     * Stack matrices vertical.
     * @param matrices Matrices to stack.
     *                 These should not have more than two rows.
     * @return New Matrix instance stacked.
     */
    public static Matrix vstack(Matrix ... matrices){
        Matrix rtn = new Matrix(matrices.length, matrices[0].col);
        for (int i = 0; i < rtn.row; i++){
            for (int j = 0; j < rtn.col; j++){
                rtn.matrix[i][j] = matrices[i].matrix[0][j];
            }
        }

        return rtn;
    }

    /**
     * Split this matrix vertically.
     * @param num Number of split.
     * @return Array of Matrix instance.
     */
    public Matrix[] vsplit(int num){
        Matrix[] rtn = new Matrix[num];
        int size = this.row / num;
        if (size * num != this.row){
            this.exit("vsplit error");
        }

        for (int i = 0; i < num; i++){
            rtn[i] = new Matrix(size, this.col);
            int row = i * size;
            for (int j = 0; j < size; j++){
                for (int k = 0; k < this.col; k++){
                    rtn[i].matrix[j][k] = this.matrix[row + j][k];
                }
            }
        }

        return rtn;
    }

    /**
     * Sort this matrix vertically.
     * @param order Order of sort.
     * @return Matrix instance.
     */
    public Matrix vsort(int[] order){
        Matrix rtn = new Matrix(order.length, this.col);

        for (int i = 0; i < order.length; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = this.matrix[order[i]][j];
            }
        }

        return rtn;
    }

    /**
     * Return absolute value of this matrix.
     * @return New matrix instance.
     */
    public Matrix abs(){
        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = Math.abs(this.matrix[i][j]);
            }
        }

        return rtn;
    }

    /**
     * Calcurate average of each columns.
     * @return Matrix instance that had everage of each columns in this matrix.
     */
    public Matrix meanCol(){
        Matrix rtn = new Matrix(1, this.col);

        double num = 0.;
        for (int j = 0; j < this.col; j++){
            num = 0;
            for (int i = 0; i < this.row; i++){
                num += this.matrix[i][j];
            }
            rtn.matrix[0][j] = num / this.row;
        }

        return rtn;
    }

    /**
     * Calcurate average of each rows.
     * @return Matrix instance that had everage of each rows in this matrix.
     */
    public Matrix meanRow(){
        Matrix rtn = new Matrix(this.row, 1);

        double num = 0.;
        for (int i = 0; i < this.row; i++){
            num = 0;
            for (int j = 0; j < this.row; j++){
                num += this.matrix[i][j];
            }
            rtn.matrix[i][0] = num / this.col;
        }

        return rtn;
    }

    /**
     * Calucurate square root each number of this matrix.
     * @return New matrix instance.
     */
    public Matrix sqrt(){
        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                rtn.matrix[i][j] = Math.sqrt(this.matrix[i][j]);
            }
        }

        return rtn;
    }

    /**
     * Power of a matrix element.
     * @return Multiplying a matrix by itself.
     */
    public Matrix pow(){
        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < rtn.row; i++){
            for (int j = 0; j < rtn.col; j++){
                rtn.matrix[i][j] = Math.pow(this.matrix[i][j], 2);
            }
        }

        return rtn;
    }

    /**
     * Power of a matrix element.
     * @param num Number to power.
     * @return Multiplying a matrix by itself.
     */
    public Matrix pow(int num){
        Matrix rtn = new Matrix(this.row, this.col);
        for (int i = 0; i < rtn.row; i++){
            for (int j = 0; j < rtn.col; j++){
                rtn.matrix[i][j] = Math.pow(this.matrix[i][j], num);
            }
        }

        return rtn;
    }

    @Override
    public String toString(){
        String str = "[";

        int i = 0;
        for (double[] ele: matrix){
            if (i == 0){
                str += "[";
            }else{
                str += "\n [";
            }
            i++;
            for (double num: ele){
                str += String.format("%.4f ", num);
            }
            str += "]";
        }
        str += "]\n";

        return str;
    }

    /**
     * Method to compare this Matrix instance and a Matrix instance.
     * Without override.
     * @param o A Matrix instance.
     * @return Is equal?
     */
    public boolean equals(Matrix o){
        if (o == this){
            return true;
        }
        if (this.row != o.row || this.col != o.col){
            return false;
        }

        for (int i = 0; i < this.row; i++){
            for (int j = 0; j < this.col; j++){
                if (this.matrix[i][j] != o.matrix[i][j]){
                    return false;
                }
            }
        }

        return true;
    }

    @Override
    public Matrix clone(){
        return new Matrix(this.matrix);
    }

    @Override
    public int hashCode(){
        return (int)this.matrix[0][0];
    }

フルバージョン

次回は

層を作ります。

次回

参考文献

Register as a new user and use Qiita more conveniently

  1. You can follow users and tags
  2. you can stock useful information
  3. You can make editorial suggestions for articles
What you can do with signing up
0
Help us understand the problem. What are the problem?