LoginSignup
5
8

More than 3 years have passed since last update.

Javaでディープラーニングを実装してみた

Posted at

ディープラーニングをスクラッチから実装して理解するアプローチの書籍としては、『ゼロから作るDeeplearning』が有名ですが、同じ趣旨の『ディープラーニングの数学』という書籍がとても分かりやすかったので、PythonのコードJavaで実装してみました。

Javaは機械学習不毛地帯(というかデータサイエンス領域がPythonの独壇場?)ですが、Javaで機械学習のロジックを理解したいという奇特な方がいれば参考にしてください。

オリジナルと同様にできる限りライブラリを使わず、スクラッチで実装しました。
(実行速度の問題で行列やベクトルの演算部分はcommons mathライブラリを利用しています)
また、書籍の内容に合わせたので、Javaっぽくない書き方や冗長になっている部分もあります。

書籍『ディープラーニングの数学』の紹介

まず最初に書籍の内容を簡単に紹介したいと思います。

書籍は4編に分かれていて、導入編が機械学習の基本、理論編が数学、実装編がPythonによる機械学習アルゴリズムの実装、発展編が少し高度な内容の数式もしくは概念レベルの解説(実装はなし)になっています。

この書籍の特徴は以下の3点だと思います。
1. 章同士の依存関係が可視化されていて全体を俯瞰しやすい
2. 数学は公式を羅列するだけでなく、式の導出まできっちりやっている
(厳密ではない部分もありますが、数学の専門書ではないのでちょうどいいバランスだと思います)
3. 実装はシンプルな線形回帰から始まって段階的にディープラーニングに進化していくので、各段階の差分を理解するだけで少しずつ理解を深めていくことができる
特に3は重要なポイントだと思います。

各編の概要は以下のとおりです。

導入編

導入編は1章のみで、機械学習の概要や仕組みについて解説されています。
また、例として身長から体重を予測するという回帰問題を平方完成という式変形で解析的に解いています。

理論編

理論編では主には数学について解説されていますが、ディープラーニングに必要な範囲に絞られています。
2章から6章まであり、2章は微分・積分、3章はベクトル・行列、4章は多変数関数の微分(偏微分)、5章は指数関数・対数関数、6章は確率・統計となっています。
レベルとしては高校数学から大学理系学部の教養科目ぐらいでしょうか。

章構成もよく練られていて、前から読み進めていくと矛盾なく理解できるようになっています。

この記事の中で数学は解説しませんが、ディープラーニングの仕組みを深く理解したい人は、理論編を読んでから後半の実装部分に取り組んだほうがいいと思います。

実装編

実践編ではPythonで機械学習のロジックを実装しています。
7章から10章まであり、7章は線形回帰モデル(回帰問題)、8章はロジスティック回帰モデル(2値分類問題)、9章はロジスティック回帰モデル(多値分類問題)、そして10章でディープラーニングを実装します。

実装編も章構成が工夫されていて、前の章に1つか2つの概念を追加すると次の章のロジックが実現できるようになっているので、段階的に理解することができます。
また、巻頭の見開きに各章の間の差分が表形式でまとめられているので、混乱したら見てみてください。

発展編

発展編は11章のみで、画像に強いCNNモデル、時系列データに強いRNNモデル、数値計算で勾配を求める数値微分の考え方、勾配降下法より効率がいい最適化アルゴリズム、過学習対策、重み行列の初期化方法などについて、数式や概念レベルで解説されています。

線形回帰モデル(回帰問題)のJava実装

それではここからJavaで実装していきます。
線形回帰モデルでは、The Boston Housing Datasetという不動産・地域データを使って不動産の価格を予測します。

線形単回帰モデル

最初は部屋数(RM)のみの1変数を使う線形単回帰モデルを実装します。

コンストラクタでは主に以下の処理を行います。
・ハイパーパラメータ(学習回数と学習率)を設定する
・学習データと正解データを作成する
・重みベクトルを1で初期化する

learnメソッドでは以下の3つの処理を繰り返して学習します。
・予測値を計算する
・誤差を計算する
・勾配に学習率を掛けて重みを更新する

LinearSingleRegression.java
package math.deeplearning.ch07;

import org.apache.commons.math3.linear.*;
import java.io.IOException;
import static math.deeplearning.common.Util.*;

/**
 * 線形単回帰モデル.
 */
public class LinearSingleRegression {
    // 学習率
    private double alpha;
    // 学習回数
    private int iters;
    // 学習データ
    private RealMatrix x;
    // 正解データ
    private RealVector yt;
    // 入力データ行数
    private int M;
    // 入力データ列数
    private int D;
    // 重みベクトル
    private RealVector W;

    /**
     * 初期化処理.
     *
     * @param iters 学習回数
     * @param alpha 学習率
     */
    public LinearSingleRegression(int iters, double alpha) throws IOException {
        this.iters = iters;
        this.alpha = alpha;

        // The Boston Housing Datasetを読み込む
        RealMatrix boston = loadBoston();
        // 学習データとして部屋数(RM)列を抽出し、ダミー変数1を付加する
        x = addBiasCol(extractCol(boston, new int[]{5}));
        // 正解データとして物件価格を抽出する
        yt = boston.getColumnVector(13);
        // 学習データの行数
        M = x.getRowDimension();
        // 学習データの列数
        D = x.getColumnDimension();

        // 重みベクトルを1で初期化する
        W = add(MatrixUtils.createRealVector(new double[D]), 1.0);
    }

    public static void main(String[] args) throws Exception {
        // 学習回数を5000、学習率を0.01に設定する
        LinearSingleRegression lsr = new LinearSingleRegression(50000, 0.01);
        // 学習する
        lsr.learn();
    }

    /**
     * 学習する.
     */
    public void learn() {
        for (int i = 0; i < iters; i++) {
            // 予測値ypを計算
            RealVector yp = dot(x, W);
            // 誤差ydを計算
            RealVector yd = sub(yp, yt);
            // 勾配に学習率を掛けて重みを更新
            W = sub(W, mult(div(dot(t(x), yd), M), alpha));

            // 一定回数学習するごとに誤差を表示する
            if (i % 100 == 0)
                System.out.println(i + " " + mean(pow(yd, 2)) / 2);
        }
    }
}

煩わしい処理をmath.deeplearning.common.Utilクラスに詰め込んだ結果、機械学習のロジック部分はある程度シンプルに実装できました。
実行すると学習が進むにつれて誤差が小さくなっていくことが分かります。

線形単回帰モデルの出力
0 154.2249338409091
100 29.617518011568446
・・・
49800 21.80032626850963
49900 21.800325071320316

線形重回帰モデル

続いて部屋数(RM)と低所得者率(LSTAT)の2変数を使う線形重回帰モデルを実装します。

LinearMultipleRegression.java
package math.deeplearning.ch07;

import org.apache.commons.math3.linear.*;
import java.io.IOException;
import static math.deeplearning.common.Util.*;

/**
 * 線形重回帰モデル.
 */
public class LinearMultipleRegression {
    // 学習率
    private double alpha;
    // 学習回数
    private int iters;
    // 学習データ
    private RealMatrix x;
    // 正解データ
    private RealVector yt;
    // 入力データ行数
    private int M;
    // 入力データ列数
    private int D;
    // 重みベクトル
    private RealVector W;

    /**
     * 初期化処理.
     *
     * @param iters 学習回数
     * @param alpha 学習率
     */
    public LinearMultipleRegression(int iters, double alpha) throws IOException {
        this.iters = iters;
        this.alpha = alpha;

        // The Boston Housing Datasetを読み込む
        RealMatrix boston = loadBoston();
        // 学習データとして部屋数(RM)列と低所得者率(LSTAT)列を抽出し、ダミー変数1を付加する
        x = addBiasCol(extractCol(boston, new int[]{5, 12}));
        // 正解データとして物件価格を抽出する
        yt = boston.getColumnVector(13);
        // 学習データの行数
        M = x.getRowDimension();
        // 学習データの列数
        D = x.getColumnDimension();

        // 重みベクトルを1で初期化する
        W = add(MatrixUtils.createRealVector(new double[D]), 1.0);
    }

    public static void main(String[] args) throws Exception {
        // 学習回数を2000、学習率を0.001に設定する
        LinearMultipleRegression lmr = new LinearMultipleRegression(2000, 0.001);
        // 学習する
        lmr.learn();
    }

    /**
     * 学習する.
     */
    public void learn() {
        for (int i = 0; i < iters; i++) {
            // 予測値ypを計算
            RealVector yp = dot(x, W);
            // 誤差ydを計算
            RealVector yd = sub(yp, yt);
            // 勾配に学習率を掛けて重みを更新
            W = sub(W, mult(div(dot(t(x), yd), M), alpha));

            // 一定回数学習するごとに誤差を表示する
            if (i % 100 == 0)
                System.out.println(i + " " + mean(pow(yd, 2)) / 2);
        }
    }
}

特徴量を1つ追加したことで線形単回帰モデルより誤差が小さくなり、学習も早く収束しています。

線形重回帰モデルの出力
0 112.06398160770748
100 25.358934200838444
・・・
1800 15.280256759397282
1900 15.280228371672587

線形単回帰モデルのコードと比較すると、実質的な差分は下記の学習データを抽出する部分のみで、線形回帰のロジック部分は変更せずに線形重回帰に対応できていることが分かります。

線形単回帰モデル
// 学習データとして、部屋数(RM)列を抽出し、ダミー変数1を付加する
x = addBiasCol(extractCol(boston, new int[]{5}));
線形重回帰モデル
// 学習データとして、部屋数(RM)列と低所得者率(LSTAT)列を抽出し、ダミー変数1を付加する
x = addBiasCol(extractCol(boston, new int[]{5, 12}));

ロジスティック回帰モデル(2値分類問題)のJava実装

次はIris Data Setというアヤメのサイズデータを使ってアヤメの種類を2クラスに分類するロジスティック回帰モデルを実装します。
Iris Data Setは先頭から50件目までがSetosa、51件目から100件目までがVersicolourというアヤメの種類のデータになっているので、先頭から100件のデータを抽出し、並び順をシャッフルします。

// Iris Data SetからSetosaとVersicolourの2種類のアヤメのデータを読み込む
RealMatrix iris = shuffle(loadIris(0, 100));

処理の流れは線形回帰と同じですが、今回はロジスティック回帰モデルで2値分類問題を解くので、活性化関数にsigmoid関数を入れて出力を0から1の確率値に変換します。

2値ロジスティック回帰モデル

BinaryLogisticRegression.java
package math.deeplearning.ch08;

import org.apache.commons.math3.linear.*;
import java.io.IOException;
import static math.deeplearning.common.Util.*;

/**
 * ロジスティック回帰モデル(2値分類).
 */
public class BinaryLogisticRegression {
    // 学習率
    private double alpha;
    // 学習回数
    private int iters;
    // 学習データ
    private RealMatrix x;
    // 評価用学習データ
    private RealMatrix xTest;
    // 正解データ
    private RealVector yt;
    // 評価用正解データ
    private RealVector ytTest;
    // 入力データ行数
    private int M;
    // 入力データ列数
    private int D;
    // 重みベクトル
    private RealVector W;

    /**
     * 初期化処理.
     *
     * @param iters 学習回数
     * @param alpha 学習率
     */
    public BinaryLogisticRegression(int iters, double alpha) throws IOException {
        this.iters = iters;
        this.alpha = alpha;

        // Iris Data SetからSetosaとVersicolourの2種類のアヤメのデータを読み込む
        RealMatrix iris = shuffle(loadIris(0, 100));
        // 学習データとしてがく片の長さの列とがく片の幅の列を抽出し、ダミー変数1を付加する
        x = addBiasCol(extractRowCol(iris, 0, 69, 0, 1));
        // テストデータとしてがく片の長さの列とがく片の幅の列を抽出し、ダミー変数1を付加する
        xTest = addBiasCol(extractRowCol(iris, 70, 99, 0, 1));
        // 学習の正解データとしてアヤメの種類を抽出する
        yt = extractRowCol(iris, 0, 69, 4);
        // テストの正解データとしてアヤメの種類を抽出する
        ytTest = extractRowCol(iris, 70, 99, 4);
        // 正解データの行数
        M = x.getRowDimension();
        // 正解データの列数
        D = x.getColumnDimension();

        // 重みベクトルを1で初期化する
        W = add(MatrixUtils.createRealVector(new double[D]), 1.0);
    }

    public static void main(String[] args) throws Exception {
        // 学習回数を10000、学習率を0.01に設定する
        BinaryLogisticRegression blr = new BinaryLogisticRegression(10000, 0.01);
        // 学習する
        blr.learn();
    }

    /**
     * 学習する.
     */
    public void learn() {
        // 学習する
        for (int i = 0; i < iters; i++) {
            // 予測値ypを計算
            RealVector yp = sigmoid(dot(x, W));
            // 誤差ydを計算
            RealVector yd = sub(yp, yt);
            // 勾配に学習率を掛けて重みを更新
            W = sub(W, mult(div(dot(t(x), yd), M), alpha));

            // 一定回数学習するごとに誤差と精度を表示する
            if (i % 10 == 0) {
                RealVector p = sigmoid(dot(xTest, W));
                System.out.print("iter = " + i + "\tloss = " + crossEntropy(ytTest, p));
                System.out.println("\tscore = " + calcAccuracy(ytTest, p));
            }
        }
    }
}

最初、正解率は50%前後ですが、最終的には100%になりました。

2値ロジスティック回帰モデルの出力
iter = 0    loss = 4.401398657630698    score = 0.5333333333333333
iter = 10   loss = 3.4820950219350593   score = 0.5333333333333333
・・・
iter = 9980 loss = 0.10275578614196225  score = 1.0
iter = 9990 loss = 0.10270185332637241  score = 1.0

線形回帰のコードと比較すると、実質的な差分は下記の予測値ypの計算にsigmoid関数を追加している部分だけであることが分かります。
(あと今回から学習データとテストデータを分けていますが、本質的な違いではありません)

線形回帰モデル
// 予測値ypを計算
RealVector yp = dot(x, W);
ロジスティック回帰モデル
// 予測値ypを計算
RealVector yp = sigmoid(dot(x, W));

ロジスティック回帰モデル(多値分類問題)のJava実装

続いて同じIris Data Setをアヤメのデータをすべて使ってアヤメの種類を3クラスに分類するロジスティック回帰モデルを実装します。

問題設定が2クラス分類から3クラス分類になるので、活性化関数をsigmoid関数からsoftmax関数に変更します。

多値ロジスティック回帰モデル

MultipleLogisticRegression.java
package math.deeplearning.ch09;

import org.apache.commons.math3.linear.*;
import java.io.IOException;
import static math.deeplearning.common.Util.*;

/**
 * ロジスティック回帰モデル(多値分類).
 */
public class MultipleLogisticRegression {
    // 学習率
    private double alpha;
    // 学習回数
    private int iters;
    // 学習データ
    private RealMatrix x;
    // 評価用学習データ
    private RealMatrix xTest;
    // 正解データ
    private RealMatrix yt;
    // 評価用正解データ
    private RealMatrix ytTest;
    // 入力データ行数
    private int M;
    // 入力データ列数
    private int D;
    // 重み行列
    private RealMatrix W;

    /**
     * 初期化処理.
     *
     * @param iters 学習回数
     * @param alpha 学習率
     */
    public MultipleLogisticRegression(int iters, double alpha) throws IOException {
        this.iters = iters;
        this.alpha = alpha;

        // Iris Data SetからSetosaとVersicolourの2種類のアヤメのデータを読み込む
        RealMatrix iris = shuffle(loadIris());
        // 学習データとしてがく片の長さの列と花弁の長さの列を抽出し、ダミー変数1を付加する
        x = addBiasCol(extractRowCol(iris, 0, 74, new int[]{0, 2}));
        // テストデータとしてがく片の長さの列と花弁の長さの列を抽出し、ダミー変数1を付加する
        xTest = addBiasCol(extractRowCol(iris, 75, 149, new int[]{0, 2}));
        // 4変数すべてを使う場合
        // x = addBiasCol(extractRowCol(iris, 0, 74, 0, 3));
        // xTest = addBiasCol(extractRowCol(iris, 75, 149, 0, 3));
        // 学習の正解データとしてアヤメの種類を抽出し、OneHotVector形式に変換する
        yt = oneHotEncode(extractRowCol(iris, 0, 74, 4), 3);
        // テストの正解データとしてアヤメの種類を抽出し、OneHotVector形式に変換する
        ytTest = oneHotEncode(extractRowCol(iris, 75, 149, 4), 3);
        // 正解データの行数
        M = x.getRowDimension();
        // 正解データの列数
        D = x.getColumnDimension();

        // 重み行列を1で初期化する
        W = add(MatrixUtils.createRealMatrix(D, 3), 1.0);
    }

    public static void main(String[] args) throws Exception {
        // 学習回数を10000、学習率を0.01に設定する
        MultipleLogisticRegression mlr = new MultipleLogisticRegression(10000, 0.01);
        mlr.learn();
    }

    /**
     * 学習する.
     */
    public void learn() {
        // 学習する
        for (int i = 0; i < iters; i++) {
            // 予測値ypを計算
            RealMatrix yp = softmax(dot(x, W));
            // 誤差ydを計算
            RealMatrix yd = sub(yp, yt);
            // 勾配に学習率を掛けて重みを更新
            W = sub(W, mult(div(dot(t(x), yd), M), alpha));

            // 一定回数学習するごとに誤差と精度を表示する
            if (i % 10 == 0) {
                RealMatrix p = softmax(dot(xTest, W));
                System.out.print("iter = " + i + "\tloss = " + crossEntropy(ytTest, p));
                System.out.println("\tscore = " + calcAccuracy(ytTest, p));
            }
        }
    }
}

最初の正解率は1/3前後ですが、最終的には97%になりました。

多値ロジスティック回帰モデルの出力
iter = 0    loss = 1.089863468306522    score = 0.30666666666666664
iter = 10   loss = 1.0505735104517255   score = 0.30666666666666664
・・・
iter = 9980 loss = 0.18412409250145656  score = 0.9733333333333334
iter = 9990 loss = 0.18403868595917505  score = 0.9733333333333334

2値分類のロジスティック回帰モデルのコードと比較すると、主な差分は以下の3点です。
・重みをベクトルから行列に変更(クラスごとに重みを持つため)
・正解データを0/1のバイナリ形式からOneHotVector形式に変更(2値分類から多値分類になったため)
・予測値ypの計算の活性化関数をsigmoid関数からsoftmax関数に変更

ディープラーニングモデルのJava実装

それではいよいよディープラーニングモデルを実装します。

ここではMNISTという手書き数字の画像データを0から9の10クラスに分類する問題を解きます。

3層ディープラーニングモデル

まずは隠れ層が1層のみの3層ディープラーニングモデルを実装します。

DeepLearning.java
package math.deeplearning.ch10;

import org.apache.commons.math3.linear.RealMatrix;
import java.io.IOException;
import java.util.*;
import static math.deeplearning.common.Util.*;

/**
 * ディープラーニング(隠れ層1層).
 */
public class DeepLearning {
    // 学習データの行数
    private int M;
    // 学習データの列数(画像のピクセル数)
    private int D;
    // 分類クラス数
    private int N;
    // 学習回数
    private int iters;
    // バッチデータサイズ
    private int batchSize;
    // 学習率
    private double alpha;

    // MNIST画像データ
    private RealMatrix xAll;
    private RealMatrix xTest;
    private RealMatrix ytAll;
    private RealMatrix ytTest;

    // 重み行列
    private RealMatrix V;
    private RealMatrix W;

    public DeepLearning(int iters, int H, int batchSize, double alpha) throws IOException {
        // MNISTデータセットを読み込む
        xAll = addBiasCol(div(loadMnistImage(MNIST_TRAIN_IMAGE_FILE_NAME), 255));
        xTest = addBiasCol(div(loadMnistImage(MNIST_TEST_IMAGE_FILE_NAME), 255));
        ytAll = oneHotEncode(loadMnistLabel(MNIST_TRAIN_LABEL_FILE_NAME), 10);
        ytTest = oneHotEncode(loadMnistLabel(MNIST_TEST_LABEL_FILE_NAME), 10);

        M = xAll.getRowDimension();
        D = xAll.getColumnDimension();
        N = ytAll.getColumnDimension();

        this.iters = iters;
        this.batchSize = batchSize;
        this.alpha = alpha;

        // 重み行列をHe Normalで初期化
        V = initW(D, H);
        W = initW(H + 1, N);
    }

    public static void main(String... args) throws Exception {
        // 学習回数を10000、隠れ層のニューロン数を128、バッチサイズを512、学習率を0.01に設定する
        DeepLearning dl = new DeepLearning(10000, 128, 512, 0.01);
        // 学習する
        dl.learn();
    }

    public void learn() {
        // ランダムサンプリングのindexを初期化
        List<Integer> indexes = new ArrayList<>();
        for (int i = 0; i < M; i++) indexes.add(i);

        for (int i = 0; i < iters; i++) {
            // 学習データのサンプリング
            List<Integer> index = randIndex(indexes, M, batchSize);
            RealMatrix x = sampling(xAll, index);
            RealMatrix yt = sampling(ytAll, index);

            // 各層の出力値を計算
            RealMatrix a = dot(x, V);
            RealMatrix b = reLU(a);
            RealMatrix b1 = addBiasCol(b);
            RealMatrix u = dot(b1, W);
            RealMatrix yp = softmax(u);
            // 各層の誤差を計算
            RealMatrix yd = sub(yp, yt);
            RealMatrix bd = mult(step(a), dot(yd, t(removeBias(W))));
            // 勾配に学習率を掛けて各層の重みを更新
            W = sub(W, mult(div(dot(t(b1), yd), batchSize), alpha));
            V = sub(V, mult(div(dot(t(x), bd), batchSize), alpha));

            // 一定回数学習するごとに誤差と精度を表示する
            if (i % 100 == 0) {
                RealMatrix p = softmax(dot(addBiasCol(reLU(dot(xTest, V))), W));
                System.out.print(i + " " + crossEntropy(ytTest, p) + " ");
                System.out.println(calcAccuracy(ytTest, p));
            }
        }
    }
}

最初、誤差は2.3前後、正解率は10%前後ですが、最終的に誤差は0.21前後、正解率は94%前後になります。

3層ディープラーニングモデルの出力
0 2.449633365625842 0.0951
100 1.5349024136564533 0.6818
・・・
9800 0.21109711296030495 0.9416
9900 0.21035221505955806 0.9419

重み行列が1つ増えて、入力層に対応する重み行列がV、隠れ層に対応する重み行列がWになります。
また重み行列の初期値は1固定ではなく、He Normalという手法で初期化します。

// 重み行列をHe Normalで初期化
V = initW(D, H);
W = initW(H + 1, N);

隠れ層の活性化関数はReLUを使います。

// 各層の出力値を計算
RealMatrix a = dot(x, V);
RealMatrix b = reLU(a);
RealMatrix b1 = addBiasCol(b);
RealMatrix u = dot(b1, W);
RealMatrix yp = softmax(u);

誤差逆伝播で層ごとに誤差を計算します。
ReLUの微分はstep関数で求められます。

// 各層の誤差を計算
RealMatrix yd = sub(yp, yt);
RealMatrix bd = mult(step(a), dot(yd, t(removeBias(W))));

各層の勾配に学習率を掛けて各層の重みを更新します。

// 勾配に学習率を掛けて各層の重みを更新
W = sub(W, mult(div(dot(t(b1), yd), batchSize), alpha));
V = sub(V, mult(div(dot(t(x), bd), batchSize), alpha));

4層ディープラーニングモデル

最後に隠れ層を1層追加した4層のディープラーニングモデルを実装します。

DeepLearning2.java
package math.deeplearning.ch10;

import org.apache.commons.math3.linear.RealMatrix;
import java.io.IOException;
import java.util.*;
import static math.deeplearning.common.Util.*;

/**
 * ディープラーニング(隠れ層2層).
 */
public class DeepLearning2 {
    // 学習データの行数
    private int M;
    // 学習データの列数(画像のピクセル数)
    private int D;
    // 分類クラス数
    private int N;
    // 学習回数
    private int iters;
    // バッチデータサイズ
    private int batchSize;
    // 学習率
    private double alpha;

    // MNIST画像データ
    private RealMatrix xAll;
    private RealMatrix xTest;
    private RealMatrix ytAll;
    private RealMatrix ytTest;

    // 重み行列
    private RealMatrix U;
    private RealMatrix V;
    private RealMatrix W;

    public DeepLearning2(int iters, int H, int batchSize, double alpha) throws IOException {
        // MNISTデータセットを読み込む
        xAll = addBiasCol(div(loadMnistImage(MNIST_TRAIN_IMAGE_FILE_NAME), 255));
        xTest = addBiasCol(div(loadMnistImage(MNIST_TEST_IMAGE_FILE_NAME), 255));
        ytAll = oneHotEncode(loadMnistLabel(MNIST_TRAIN_LABEL_FILE_NAME), 10);
        ytTest = oneHotEncode(loadMnistLabel(MNIST_TEST_LABEL_FILE_NAME), 10);

        M = xAll.getRowDimension();
        D = xAll.getColumnDimension();
        N = ytAll.getColumnDimension();

        this.iters = iters;
        this.batchSize = batchSize;
        this.alpha = alpha;

        // 重み行列をHe Normalで初期化
        U = initW(D, H);
        V = initW(H + 1, H);
        W = initW(H + 1, N);
    }

    public static void main(String... args) throws Exception {
        // 学習回数を10000、隠れ層のニューロン数を128、バッチサイズを512、学習率を0.01に設定する
        DeepLearning2 dl2 = new DeepLearning2(10000, 128, 512, 0.01);
        // 学習する
        dl2.learn();
    }

    public void learn() {
        // ランダムサンプリングのindexを初期化
        List<Integer> indexes = new ArrayList<>();
        for (int i = 0; i < M; i++) indexes.add(i);

        for (int i = 0; i < iters; i++) {
            // 学習データをサンプリング
            List<Integer> index = randIndex(indexes, M, batchSize);
            RealMatrix x = sampling(xAll, index);
            RealMatrix yt = sampling(ytAll, index);

            // 各層の出力値を計算
            RealMatrix a = dot(x, U);
            RealMatrix b = reLU(a);
            RealMatrix b1 = addBiasCol(b);
            RealMatrix c = dot(b1, V);
            RealMatrix d = reLU(c);
            RealMatrix d1 = addBiasCol(d);
            RealMatrix u = dot(d1, W);
            RealMatrix yp = softmax(u);
            // 各層の誤差を計算
            RealMatrix yd = sub(yp, yt);
            RealMatrix dd = mult(step(c), dot(yd, t(removeBias(W))));
            RealMatrix bd = mult(step(a), dot(dd, t(removeBias(V))));
            // 勾配に学習率を掛けて各層の重みを更新
            W = sub(W, mult(div(dot(t(d1), yd), batchSize), alpha));
            V = sub(V, mult(div(dot(t(b1), dd), batchSize), alpha));
            U = sub(U, mult(div(dot(t(x), bd), batchSize), alpha));

            // 一定回数学習するごとに誤差と精度を表示
            if (i % 100 == 0) {
                RealMatrix p = softmax(dot(addBiasCol(reLU(dot(addBiasCol(reLU(dot(xTest, U))), V))), W));
                System.out.print(i + " " + crossEntropy(ytTest, p) + " ");
                System.out.println(calcAccuracy(ytTest, p));
            }
        }
    }
}

3層ディープラーニングモデルの誤差は0.21前後、正解率は94%前後でしたが、4層ディープラーニングモデルは隠れ層が2層になって表現力が増したことで、誤差は0.15前後、正解率は95%前後に改善します。

4層ディープラーニングモデルの出力
0 2.418195100372308 0.1035
100 1.4860509069098333 0.6518
・・・
9800 0.15087335052084305 0.9552
9900 0.14996068028907877 0.9556

重み行列の初期化処理では、追加した隠れ層に対応する重み行列が1つ増えます。

3層ディープラーニングモデル
// 重み行列をHe Normalで初期化
V = initW(D, H);
W = initW(H + 1, N);
4層ディープラーニングモデル
// 重み行列をHe Normalで初期化
U = initW(D, H);
V = initW(H + 1, H);
W = initW(H + 1, N);

各層の出力値の計算では、追加された隠れ層1層分の処理が増えます。

3層ディープラーニングモデル
// 各層の出力値を計算
RealMatrix a = dot(x, V);
RealMatrix b = reLU(a);
RealMatrix b1 = addBiasCol(b);
RealMatrix u = dot(b1, W);
RealMatrix yp = softmax(u);
4層ディープラーニングモデル
// 各層の出力値を計算
RealMatrix a = dot(x, U);
RealMatrix b = reLU(a);
RealMatrix b1 = addBiasCol(b);
RealMatrix c = dot(b1, V);
RealMatrix d = reLU(c);
RealMatrix d1 = addBiasCol(d);
RealMatrix u = dot(d1, W);
RealMatrix yp = softmax(u);

誤差逆伝播でも同様に追加された隠れ層の処理が増えます。

3層ディープラーニングモデル
// 各層の誤差を計算
RealMatrix yd = sub(yp, yt);
RealMatrix bd = mult(step(a), dot(yd, t(removeBias(W))));
4層ディープラーニングモデル
// 各層の誤差を計算
RealMatrix yd = sub(yp, yt);
RealMatrix dd = mult(step(c), dot(yd, t(removeBias(W))));
RealMatrix bd = mult(step(a), dot(dd, t(removeBias(V))));

重み行列の更新処理も追加された隠れ層に対応する重み行列の更新処理が増えます。

3層ディープラーニングモデル
// 勾配に学習率を掛けて各層の重みを更新
W = sub(W, mult(div(dot(t(b1), yd), batchSize), alpha));
V = sub(V, mult(div(dot(t(x), bd), batchSize), alpha));
4層ディープラーニングモデル
// 勾配に学習率を掛けて各層の重みを更新
W = sub(W, mult(div(dot(t(d1), yd), batchSize), alpha));
V = sub(V, mult(div(dot(t(b1), dd), batchSize), alpha));
U = sub(U, mult(div(dot(t(x), bd), batchSize), alpha));

隠れ層が1層の3層ディープラーニングモデルのコードと比較すると、隠れ層の重み行列と計算処理が1層分追加されているだけであることが分かります。

まとめ

Javaで線形回帰モデルを実装して回帰問題を解くことからスタートして、最終的にはシンプルなディープラーニングモデルを実装して手書き数字をある程度正しく識別することができました。
著者さんや出版社の回し者ではありませんが、ディープラーニングの仕組みを理解したい方はぜひ書籍も読んでみてください。
そして「なるほどなー!」と思った方は、ご自分の得意な言語で実装し直してみてください。
読むだけの数倍は理解が深まると思います。

ちなみに『ゼロから作るDeeplearning』のPythonのコードJavaで実装してみたのですが、7章のCNNでデータが行列から4次元のテンソルになってJavaできれいに実装できず、6章までで止まってしまっています・・・。
またモチベーションが復活したら7章も実装して記事にまとめたいと思います。

最後まで読んでいただきありがとうございました!

5
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
8