この記事は、Javaでディープラーニングをやってみたいけど、第一歩目が踏み出せない人の背中を押すことを目的として書きました。簡単にディープラーニングの力を体験できるので、ぜひ読んでみて下さい。
対象読者
以下のいずれかの人を読者として想定しています。
- ディープラーニングをやってみたいけど、Pythonでプログラムを書いたことないからなぁ…というJavaプログラマー
- 何らかの理由により、Javaでディープラーニングしたい方
- Deeplearning4jを使ってみたい方
この記事を読むことで、IntelliJでDeeplearning4jを使ったディープラーニングの開発環境を構築して、手書きの数字を識別できるようになります。また、数字の識別の他にも多数あるディープラーニングのサンプルを実行できるようになります。ただし、この記事からはディープラーニングの理論を学ぶことはできません。ディープラーニングの理論を学ぶには「ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装」を読むことをお勧めします。
まずは体験する
ディープラーニングの学習を始める最初のステップは、学習するモチベーションを高めるために、実際にプログラムを動かして「体験すること」だと考えます。以下を見て下さい。
このプログラムは、手書きの数字を入力すると、画面下部の「Prediction: 」にそれを識別(予測)した数字を表示します。
もし、あなたが手書きの数字の「7」の画像を識別するプログラムを特殊なライブラリーを使わず書こうと思ったら、どのようなアルゴリズムにしますか?画像の上部に横方向の線があり、その終点から下方向への線があれば「7」と識別するようなアルゴリズムであれば大丈夫でしょうか。そのアルゴリズムは、以下のいずれも正しく「7」と識別することができますか?そして、それをどのようにプログラムに落とし込みますか?
人間であれば簡単な数字の識別もコンピューターには(アルゴリズムで実現しようとすると)それほど簡単なことではありません。しかし、ディープラーニングはこれを実現できます。そして、この記事では上記プログラムを実際に動かします。
システム要件
Deeplearning4jのシステム要件は以下の通りです。
- Java 1.7以上 64-Bitバージョン(JAVA_HOMEもセットすること)
- Maven または Gradle
- IntelliJ または Eclipse
- Git
検証に利用した環境
この記事を書くにあたって利用した環境は、以下の通りです。
- OS:Ubuntu 17.10
- メモリー:3.8 GB
- CPU:Intel® Core™ i5-6500 CPU @ 3.20GHz × 4
- GPU:なし(この記事の動作検証では無くても問題ありません)
- IDE: IntelliJ IDEA 2018.1.4 (Community Edition)
$ mvn -version
Apache Maven 3.5.0
Maven home: /usr/share/maven
Java version: 1.8.0_171, vendor: Oracle Corporation
Java home: /usr/lib/jvm/java-8-openjdk-amd64/jre
Default locale: ja_JP, platform encoding: UTF-8
OS name: "linux", version: "4.13.0-21-generic", arch: "amd64", family: "unix"
$ git --version
git version 2.14.1
注意事項
ディスク容量をかなり使うので、dl4j-examples
のサブプロジェクトのdl4j-examples
だけをビルドした方がいいかもしれません。プロジェクト全体をビルドする場合は、少なくとも15GBの空きが必要です。
以下のディレクトリーに大量のjarがダウンロードされるので、git clone
したDeeplearning4jのディレクトリーを削除するときはこのディレクトリーも削除して下さい。
.m2/repository/org/deeplearning4j/
開発環境の構築手順
開発環境の構築は以下を実行するだけです。
$ git clone https://github.com/deeplearning4j/dl4j-examples.git
$ cd dl4j-examples/dl4j-examples
$ mvn clean install
ただし、画面の描画に使用するので、必要に応じてOpenJFX(Java FX)もインストールして下さい。Java FXが無いと、ビルド時に次のようなエラーになります。
[ERROR] /home/tamura/git/dl4j-examples/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/character/harmonies/Piano.java:[4,24] パッケージjavafx.animationは存在しません
Ubuntu 17.10であれば、以下でインストールできます。
$ sudo apt-get install openjfx
動作確認
開発環境が構築できたら、ディープラーニングにおける「Hello world!」のような位置づけ(?)の手書き数字の識別(先程見たサンプルです)を試してみましょう。ディープラーニングでは、まず大量のデータを「学習」して最適なパラメーターを導出します。そして、それをもとに「予測」を行います。手順は以下の通りです。
-
ビルドが完了したら、IntelliJでプロジェクトを開く。
-
org.deeplearning4j.examples.convolution.mnist.MnistClassifier
のソースコードを開いて、実行(エディタの左サイドにある緑三角の実行ボタンをクリック)する(このステップは「学習」を行います)。
以下のようなメッセージが出力されます。/usr/lib/jvm/java-8-openjdk-amd64/bin/java -javaagent:/home/tamura/idea-IC-181.5087.20 ・・・(略)・・・ org.deeplearning4j.examples.convolution.mnist.MnistClassifier o.d.e.c.m.MnistClassifier - Data load and vectorization... o.d.i.r.BaseImageRecordReader - ImageRecordReader: 10 label classes inferred using label generator ParentPathLabelGenerator o.d.i.r.BaseImageRecordReader - ImageRecordReader: 10 label classes inferred using label generator ParentPathLabelGenerator o.d.e.c.m.MnistClassifier - Network configuration and training... o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend o.n.n.NativeOpsHolder - Number of threads used for NativeOps: 1 o.n.n.Nd4jBlas - Number of threads used for BLAS: 1 o.n.l.a.o.e.DefaultOpExecutioner - Backend used: [CPU]; OS: [Linux] o.n.l.a.o.e.DefaultOpExecutioner - Cores: [4]; Memory: [0.9GB]; o.n.l.a.o.e.DefaultOpExecutioner - Blas vendor: [MKL] o.d.n.m.MultiLayerNetwork - Starting MultiLayerNetwork with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE] o.d.o.l.ScoreIterationListener - Score at iteration 0 is 2.4694731759178388 o.d.o.l.ScoreIterationListener - Score at iteration 10 is 1.078069156582683 o.d.o.l.ScoreIterationListener - Score at iteration 20 is 0.7327581484283221 ・・・(略)・・・ o.d.o.l.ScoreIterationListener - Score at iteration 1100 is 0.20279510458591593 o.d.o.l.ScoreIterationListener - Score at iteration 1110 is 0.10997898485405874 o.d.e.c.m.MnistClassifier - Completed epoch 0 o.d.e.c.m.MnistClassifier - ========================Evaluation Metrics======================== # of classes: 10 Accuracy: 0.9891 Precision: 0.9891 Recall: 0.9890 F1 Score: 0.9891 Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes) =========================Confusion Matrix========================= 0 1 2 3 4 5 6 7 8 9 --------------------------------------------------- 973 0 0 0 0 0 2 2 3 0 | 0 = 0 0 1132 0 1 0 1 1 0 0 0 | 1 = 1 2 3 1018 1 0 0 1 6 1 0 | 2 = 2 0 0 1 1000 0 3 0 4 1 1 | 3 = 3 0 0 1 0 973 0 3 0 0 5 | 4 = 4 1 0 0 5 0 882 2 1 1 0 | 5 = 5 5 2 0 0 2 3 944 0 2 0 | 6 = 6 0 2 4 0 0 0 0 1017 2 3 | 7 = 7 3 0 2 1 0 0 1 2 961 4 | 8 = 8 4 2 1 1 3 0 0 6 1 991 | 9 = 9 Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times ================================================================== Process finished with exit code 0
-
org.deeplearning4j.examples.convolution.mnist.MnistClassifierUI
のソースコードを開いて、実行する -
手書きの数字を受け付けるJava FX画面が表示されるので、数字を入力する(このステップは「予測」を行います)。
ソースコードリーディング
では、これを実現する仕組みはどのようになっているのでしょうか?手書きの数字画像を「学習」するMnistClassifierのソースコードを上から見てみましょう。
これ以降のセクションの理解にはディープラーニングの基礎知識が必要です。
最初は、手書き数字画像のダウンロード先とそれを解凍する一時ディクレトリーの定数とロガーのフィールド変数です。
public class MnistClassifier {
private static final Logger log = LoggerFactory.getLogger(MnistClassifier.class);
private static final String basePath = System.getProperty("java.io.tmpdir") + "/mnist";
private static final String dataUrl = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
そして、このクラスのmain()
メソッドになります。このメソッドを呼び出すと、学習を開始します。入力画像が1チャンネル、縦横それぞれ28ピクセルの3次元データとして入力層に渡ります。1から10までの数字を識別するので出力層の数は10、バッチサイズは54、エポック数は1です。
public static void main(String[] args) throws Exception {
int height = 28;
int width = 28;
int channels = 1; // single channel for grayscale images
int outputNum = 10; // 10 digits classification
int batchSize = 54;
int nEpochs = 1;
int iterations = 1;
int seed = 1234;
Random randNumGen = new Random(seed);
次に、GitHubから7万個の手書き数字画像が圧縮されたmnist_png.tar.gz
をダウンロードして、解凍します。
log.info("Data load and vectorization...");
String localFilePath = basePath + "/mnist_png.tar.gz";
if (DataUtilities.downloadFile(dataUrl, localFilePath))
log.debug("Data downloaded from {}", dataUrl);
if (!new File(basePath + "/mnist_png").exists())
DataUtilities.extractTarGz(localFilePath, basePath);
学習(訓練)用のデータ(6万件)とテスト用のデータ(1万件)に分けて、それぞれtrainIter
、testIter
のイテレーター変数に格納します。
// vectorization of train data
File trainData = new File(basePath + "/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // parent path as the image label
ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
trainRR.initialize(trainSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);
// pixel values from 0-255 to 0-1 (min-max scaling)
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
scaler.fit(trainIter);
trainIter.setPreProcessor(scaler);
// vectorization of test data
File testData = new File(basePath + "/mnist_png/testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
testIter.setPreProcessor(scaler); // same normalization for better results
次は、学習率の設定をlrSchedule
という変数名のHashMap
に追加します。学習率が大きいと、前半は早く学習が進みますが、後半はなかなか収束しないため、処理件数に応じて、学習率を低くしています。このプログラムでは、1,111回(=学習用のデータ:60,000 / バッチサイズ:54)の学習が繰り返されます。その繰り返し回数に応じて、徐々に学習率を下げています。
log.info("Network configuration and training...");
Map<Integer, Double> lrSchedule = new HashMap<>();
lrSchedule.put(0, 0.06); // iteration #, learning rate
lrSchedule.put(200, 0.05);
lrSchedule.put(600, 0.028);
lrSchedule.put(800, 0.0060);
lrSchedule.put(1000, 0.001);
ここからニューラルネットワーク構築のメインの処理を行います。NeuralNetConfiguration.Builder()
のlayer()
メソッドで呼び出すことで、ニューラルネットワークにレイヤーを追加します。入力層は追加不要なので、最初に追加するレイヤーはConvolutionLayer
(畳み込み層)になります。その次にSubsamplingLayer
(プーリング層)を追加しています。さらに、それを繰り返し、DenseLayer
(全結合層)が続き、最後にOutputLayer
(出力層)を追加しています。画像認識でよく利用されるCNN(畳み込みニューラルネットワーク)の構成です。
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.l2(0.0005)
.updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, lrSchedule)))
.weightInit(WeightInit.XAVIER)
.list()
.layer(0, new ConvolutionLayer.Builder(5, 5)
.nIn(channels)
.stride(1, 1)
.nOut(20)
.activation(Activation.IDENTITY)
.build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
.layer(2, new ConvolutionLayer.Builder(5, 5)
.stride(1, 1) // nIn need not specified in later layers
.nOut(50)
.activation(Activation.IDENTITY)
.build())
.layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
.layer(4, new DenseLayer.Builder().activation(Activation.RELU)
.nOut(500).build())
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(28, 28, 1)) // InputType.convolutional for normal image
.backprop(true).pretrain(false).build();
Activation.IDENTITY
は活性化関数に恒等関数( $\scriptsize{ f(x) = x }$ 、つまり何もしない)を、Activation.RELU
はReLU関数、Activation.SOFTMAX
はソフトマックス関数を使用することを意味します。
言葉だけでは分かりづらいかもしれないので、ニューラルネットワークの構成を図解してみました。
この図とソースコードを見比べながら、分からない部分はDeeplearning4j Cheat Sheetなどを確認してみて下さい(長くなるので、全ては説明しません)。
では、次に進みます。MultiLayerNetwork
のsetListeners()
メソッドで呼び出しておくと、学習状況を定期的に出力してくれます。
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(10));
log.debug("Total num of params: {}", net.numParams());
最後に、fit()
メソッドを呼び出して、訓練用のデータを用いて学習を開始します。学習が完了したら、MultiLayerNetwork.evaluate()
にテストデータを与えて、それを評価します。最後に、導出したパラメーターをまとめてminist-model.zip
に保存します。
// evaluation while training (the score should go down)
for (int i = 0; i < nEpochs; i++) {
net.fit(trainIter);
log.info("Completed epoch {}", i);
Evaluation eval = net.evaluate(testIter);
log.info(eval.stats());
trainIter.reset();
testIter.reset();
}
ModelSerializer.writeModel(net, new File(basePath + "/minist-model.zip"), true);
}
}
もう一つのクラスMnistClassifierUI
は、このminist-model.zip
を読み込んでニューラルネットワークを構築し、手書きの数字画像を「予測」します。MnistClassifierUI
についての詳しい解説は省略します。
応用
MnistClassifier
のソースコードを少し変更して、いろいろな実験をしてみましょう。
学習状況をグラフ化する
MultiLayerNetwork
のsetListeners()
メソッドに与えるリスナークラスを他のものに変更してみましょう。こちらのページで紹介されていたリスナークラスをセットしてみます。
// net.setListeners(new ScoreIterationListener(10));
// 上の行をコメントアウトして、下の4行を追加
UIServer uiServer = UIServer.getInstance();
StatsStorage statsStorage = new InMemoryStatsStorage();
uiServer.attach(statsStorage);
net.setListeners(Arrays.asList(new ScoreIterationListener(1), new StatsListener(statsStorage)));
ソースコードを修正したら、再度プログラムを実行します。以下が標準出力に出力されるので、
o.d.u.p.PlayUIServer - DL4J UI Server started at http://localhost:9000
http://localhost:9000 にアクセスすると、以下のように現在の学習の状況を分かりやすく可視化したグラフが表示されます。
画面右側の「言語(Language)」タブをクリックして、日本語を選択してあります。
「システム」タブをクリックすれば、このニューラルネットワークの構成が簡単な図で確認できます。
最適化のアルゴリズムを変更する
次に、最適化のアルゴリズムを確率的勾配降下法(Stocastic Gradient Descent:SGD)に変更してみましょう。NeuralNetConfiguration.Builder()
のupdater()
メソッドの引数のNesterovs
をSgd
に変更します。
そして、プログラムを実行すると、以下のような結果になりました。
========================Evaluation Metrics========================
# of classes: 10
Accuracy: 0.9698
Precision: 0.9696
Recall: 0.9697
F1 Score: 0.9697
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
=========================Confusion Matrix=========================
0 1 2 3 4 5 6 7 8 9
---------------------------------------------------
969 0 1 0 0 2 3 1 4 0 | 0 = 0
0 1120 3 2 0 1 3 0 6 0 | 1 = 1
6 2 993 4 6 3 3 9 6 0 | 2 = 2
1 0 7 976 0 7 0 9 7 3 | 3 = 3
1 1 2 0 955 0 5 2 2 14 | 4 = 4
2 1 0 11 1 866 5 1 3 2 | 5 = 5
10 3 1 0 6 3 933 0 2 0 | 6 = 6
2 8 16 2 1 0 0 982 3 14 | 7 = 7
6 0 1 4 4 5 4 6 941 3 | 8 = 8
5 7 0 9 11 7 1 5 1 963 | 9 = 9
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
若干精度が落ちていますね。いくつか試してみましたが、このケースではNesterovs
(Nesterovの加速勾配降下法)がいいようです。
重みの初期値をゼロにする
あえて正しくない設定で動かしてみましょう。NeuralNetConfiguration.Builder()
のweightInit()
メソッドにWeightInit.ZERO
を与えて、重みの初期値をゼロにします。
こうすると、スコアは2.3
前後でほとんど変わることなく終了します。そして、最終的にすべての画像を「1」と予測します。
=========================Confusion Matrix=========================
0 1 2 3 4 5 6 7 8 9
---------------------------------------------------
0 980 0 0 0 0 0 0 0 0 | 0 = 0
0 1135 0 0 0 0 0 0 0 0 | 1 = 1
0 1032 0 0 0 0 0 0 0 0 | 2 = 2
0 1010 0 0 0 0 0 0 0 0 | 3 = 3
0 982 0 0 0 0 0 0 0 0 | 4 = 4
0 892 0 0 0 0 0 0 0 0 | 5 = 5
0 958 0 0 0 0 0 0 0 0 | 6 = 6
0 1028 0 0 0 0 0 0 0 0 | 7 = 7
0 974 0 0 0 0 0 0 0 0 | 8 = 8
0 1009 0 0 0 0 0 0 0 0 | 9 = 9
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
これは、全ての重みの値が均一の更新されてしまうからですね。
まとめ
ということで、簡単にDeeplearning4jを使った手書きの数字の識別を行ってみました。これで、Javaでディープラーニングする第一歩目は踏み出せたと思います。git clone
したソースコードにはこれ以外にも多種のサンプルが含まれています。次のステップとして、他のプログラムを実行してみるのもいいかもしれません。理論を理解していない人は、前述の書籍を読むことをお勧めします。