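These are my notes on trying handwritten digit recognition with Deeplearning4j (DL4J below).
Everything needed to try DL4J on another machine should be here.
For the official Quick Start, see the DL4J site (it appears to have been updated in 2016).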
Environment
- Hardware
  - CPU: Intel(R) Core(TM)2 Duo CPU @ 2.93GHz 2.93GHz
  - RAM: 4.00 GB
- OS
  - Windows 10 Home 1511 10586.218
- Software
  - Oracle JDK 8 Update 45 (Windows x86)
  - Gradle 2.13
There is no GPU, so none is used, and no native BLAS library is installed.
I actually did the development in IntelliJ IDEA.
Implementation
I created the following files and directories.
./
|
+- build.gradle
|
+- src/main/java/
|  |
|  +- jp/hashiwa/dl4j/convolution/
|     |
|     +- CNNMnistCreator.java
|     +- CNNMnistReader.java
|
+- logs/
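Source code
I wrote two classes: one that trains on the data and writes the resulting neural network to files (CNNMnistCreator.java), and one that reads those files back and evaluates the network (CNNMnistReader.java).
First, the training side. Only the main method is shown; the complete source code is available separately.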
public static void main(String[] args) throws Exception {
    int numRows = 28;
    int numColumns = 28;
    int nChannels = 1;
    int outputNum = 10;
    int numSamples = 2000;
    int batchSize = 500;
    int iterations = 10;
    // Use 80% of each batch for training; the rest is held out for evaluation.
    int splitTrainNum = (int) (batchSize*.8);
    int seed = 123;
    int listenerFreq = iterations/5;
    DataSet mnist;
    SplitTestAndTrain trainTest;
    DataSet trainInput;
    List<INDArray> testInput = new ArrayList<>();
    List<INDArray> testLabels = new ArrayList<>();
    String binFile = "logs/convolution.bin";
    String confFile = "logs/convolution.json";

    log.info("Load data....");
    DataSetIterator mnistIter = new MnistDataSetIterator(batchSize,numSamples, true);

    log.info("Build model....");
    MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .regularization(true).l2(0.0005)
            .learningRate(0.01)
            .weightInit(WeightInit.XAVIER)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .list(6)
            .layer(0, new ConvolutionLayer.Builder(5, 5)
                    .nIn(nChannels)
                    .stride(1, 1)
                    .nOut(20)
                    // .dropOut(0.5)
                    .activation("relu")
                    .build())
            .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                    .kernelSize(2,2)
                    .stride(2,2)
                    .build())
            .layer(2, new ConvolutionLayer.Builder(5, 5)
                    .nIn(nChannels)
                    .stride(1, 1)
                    .nOut(50)
                    // .dropOut(0.5)
                    .activation("relu")
                    .build())
            .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                    .kernelSize(2,2)
                    .stride(2,2)
                    .build())
            .layer(4, new DenseLayer.Builder().activation("relu")
                    .nOut(500)
                    .build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nOut(outputNum)
                    .activation("softmax")
                    .build())
            .backprop(true).pretrain(false);
    new ConvolutionLayerSetup(builder, numRows, numColumns, nChannels);

    MultiLayerConfiguration conf = builder.build();
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    log.info("Train model....");
    model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq)));
    while(mnistIter.hasNext()) {
        mnist = mnistIter.next();
        trainTest = mnist.splitTestAndTrain(splitTrainNum, new Random(seed)); // train set that is the result
        trainInput = trainTest.getTrain(); // get feature matrix and labels for training
        // Keep the held-out part of this batch for the evaluation after training.
        testInput.add(trainTest.getTest().getFeatureMatrix());
        testLabels.add(trainTest.getTest().getLabels());
        model.fit(trainInput);
    }

    log.info("Evaluate weights....");

    log.info("Evaluate model....");
    Evaluation eval = new Evaluation(outputNum);
    for(int i = 0; i < testInput.size(); i++) {
        INDArray output = model.output(testInput.get(i));
        eval.eval(testLabels.get(i), output);
    }
    log.info(eval.stats());

    log.info("Save model....");
    // Weights go to the binary file, the network structure to the JSON file.
    try (OutputStream fos = new FileOutputStream(binFile);
         DataOutputStream dos = new DataOutputStream(fos)) {
        Nd4j.write(model.params(), dos);
    }
    FileUtils.writeStringToFile(new File(confFile), model.getLayerWiseConfigurations().toJson());

    log.info("****************Example finished********************");
}
The handwritten digit data is 2,000 samples from MNIST. The org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator class makes it easy to obtain.
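As a quick check on how the constants in the listing interact, the arithmetic can be sketched in plain Java. BatchPlan and its methods are names made up for this illustration and are not part of DL4J. The sketch also assumes that the listener logs a score whenever the iteration number is a multiple of listenerFreq, which is what the log output suggests.

```java
// Illustration only: BatchPlan is a hypothetical helper, not part of DL4J.
final class BatchPlan {
    /** Number of mini-batches for the given sample count. */
    static int batches(int numSamples, int batchSize) {
        return numSamples / batchSize;
    }

    /** Training examples per batch, computed the same way as splitTrainNum. */
    static int trainPerBatch(int batchSize) {
        return (int) (batchSize * .8);
    }

    /** Score lines per batch, assuming one line when iteration % listenerFreq == 0. */
    static int scoreLinesPerBatch(int iterations, int listenerFreq) {
        return (iterations + listenerFreq - 1) / listenerFreq;
    }

    public static void main(String[] args) {
        int numSamples = 2000, batchSize = 500, iterations = 10;
        int listenerFreq = iterations / 5;

        int batches = batches(numSamples, batchSize);
        int train = trainPerBatch(batchSize);
        int test = batchSize - train;

        System.out.println("batches            = " + batches);
        System.out.println("train per batch    = " + train);
        System.out.println("held out per batch = " + test);
        System.out.println("held out in total  = " + batches * test);
        System.out.println("score lines        = "
                + batches * scoreLinesPerBatch(iterations, listenerFreq));
    }
}
```

With the values from the listing this gives 4 batches of 400 training and 100 held-out examples each, so 400 examples in total are used for the evaluation.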
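The layers are arranged as follows, taken unchanged from the LeNet example in the DL4J samples.
Input layer (784) --> convolution layer (20) --> pooling layer --> convolution layer (50) --> pooling layer --> fully connected layer (500) --> output layer (10)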
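To see how the 28 x 28 image shrinks on its way through these layers, here is a small sketch of the size arithmetic. LeNetShapes is a hypothetical helper written for this note, and it assumes that neither the convolution nor the pooling layers use padding (the listing does not set any).

```java
// Illustration only: LeNetShapes is a hypothetical helper, not part of DL4J.
final class LeNetShapes {
    /** Output width (or height) of a convolution or pooling window, assuming no padding. */
    static int outSize(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    public static void main(String[] args) {
        int size = 28; // 28 x 28 input with 1 channel = 784 values
        System.out.println("input: " + size + "x" + size + "x1");

        size = outSize(size, 5, 1); // layer 0: 5x5 convolution, stride 1, 20 feature maps
        System.out.println("conv1: " + size + "x" + size + "x20");

        size = outSize(size, 2, 2); // layer 1: 2x2 max pooling, stride 2
        System.out.println("pool1: " + size + "x" + size + "x20");

        size = outSize(size, 5, 1); // layer 2: 5x5 convolution, stride 1, 50 feature maps
        System.out.println("conv2: " + size + "x" + size + "x50");

        size = outSize(size, 2, 2); // layer 3: 2x2 max pooling, stride 2
        System.out.println("pool2: " + size + "x" + size + "x50");

        // The fully connected layer sees the pooled feature maps flattened into one vector.
        System.out.println("dense input: " + size * size * 50);
    }
}
```

Under that assumption the sizes go 28, 24, 12, 8, 4, so the fully connected layer receives 4 x 4 x 50 = 800 values.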
Once training finishes, the program evaluates the network once and then saves it to files: logs/convolution.json holds the network structure as JSON, and logs/convolution.bin holds the weight parameter values in binary form.
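The idea of writing a flat parameter vector through a DataOutputStream and reading it back can be tried without DL4J. The sketch below uses a byte layout I made up (a length followed by the floats). It is not the format that Nd4j.write produces, and FloatRoundTrip is a hypothetical class name.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;

// Illustration only: a made-up layout (int length, then floats), NOT the Nd4j.write format.
final class FloatRoundTrip {
    /** Serializes the parameters into a byte array. */
    static byte[] write(float[] params) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream dos = new DataOutputStream(bytes)) {
            dos.writeInt(params.length);
            for (float p : params) {
                dos.writeFloat(p);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Reads back parameters written by write(). */
    static float[] read(byte[] data) {
        try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(data))) {
            float[] params = new float[dis.readInt()];
            for (int i = 0; i < params.length; i++) {
                params[i] = dis.readFloat();
            }
            return params;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        float[] original = {0.5f, -1.25f, 3.0f};
        byte[] stored = write(original);
        float[] restored = read(stored);
        System.out.println("bytes stored:  " + stored.length);
        System.out.println("round trip ok: " + Arrays.equals(original, restored));
    }
}
```

Keeping the structure (JSON) and the weights (binary) in separate files, as the program does, means the weights file stays a plain list of numbers whose meaning comes entirely from the configuration it is loaded into.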
Next, the reading side. Only the main method is shown; the complete source code is available separately.
public static void main(String[] args) throws Exception {
    String confFile = "logs/convolution.json";
    String binFile = "logs/convolution.bin";
    int outputNum = 10;
    Logger log = LoggerFactory.getLogger(CNNMnistReader.class);

    log.info("Load stored model ...");
    // Rebuild the network structure from the saved JSON configuration.
    MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(FileUtils.readFileToString(new File(confFile)));
    // Read the saved weights; try-with-resources closes the stream even if reading fails.
    INDArray newParams;
    try (DataInputStream dis = new DataInputStream(new FileInputStream(new File(binFile)))) {
        newParams = Nd4j.read(dis);
    }
    MultiLayerNetwork model = new MultiLayerNetwork(confFromJson);
    model.init();
    model.setParams(newParams);
    System.out.println(model);

    log.info("Evaluate weights....");

    log.info("Evaluate model....");
    MnistDataSetIterator testIter = new MnistDataSetIterator(100, 500);
    Evaluation eval = new Evaluation(outputNum);
    while (testIter.hasNext()) {
        DataSet dataSet = testIter.next();
        INDArray output = model.output(dataSet.getFeatureMatrix());
        eval.eval(dataSet.getLabels(), output);
    }
    log.info(eval.stats());
}
Build script (build.gradle)
group 'jp.hashiwa.dl4j.sample'
version '1.0-SNAPSHOT'

apply plugin: 'application'

sourceCompatibility = 1.8

def isWin = 0 <= System.getProperty('os.name').indexOf('Windows')

repositories {
    mavenCentral()
}

dependencies {
    // testCompile group: 'junit', name: 'junit', version: '4.11'
    compile 'org.deeplearning4j:deeplearning4j-core:0.4-rc3.8'
    compile 'org.deeplearning4j:deeplearning4j-nlp:0.4-rc3.8'
    compile 'org.deeplearning4j:deeplearning4j-ui:0.4-rc3.8'
    compile 'com.google.guava:guava:19.0'
    compile 'org.nd4j:nd4j-x86:0.4-rc3.8' // Nd4j Cpu
    compile 'org.nd4j:canova-nd4j-image:0.0.0.14'
    compile 'org.nd4j:canova-nd4j-codec:0.0.0.14'
    // Support for reading and writing YAML-encoded data via Jackson abstractions.
    compile 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.5.1'
}

task wrapper(type: Wrapper) {
    gradleVersion = '2.13'
}

if (isWin) {
    task pathingJar(type: Jar) {
        dependsOn configurations.runtime
        appendix = 'pathing'

        doFirst {
            manifest {
                // Build the Class-Path for absolute paths based on runtime dependencies.
                attributes "Class-Path": configurations.runtime.files.collect {
                    it.toURL().toString().replaceFirst(/file:\/+/, '/')
                }.join(' ')
            }
        }

        // assetCompile will be executed for all Jar-type tasks
        // (see https://github.com/bertramdev/asset-pipeline/blob/master/asset-pipeline-gradle/src/main/groovy/asset/pipeline/gradle/AssetPipelinePlugin.groovy#L85)
        // at least exclude the assets from pathing jar
        exclude { it.file.absolutePath.contains('assetCompile') }
    }
}

run {
    jvmArgs '-server'
    jvmArgs '-showversion'

    if (project.hasProperty('main')) {
        main(project.main)
    }

    if (isWin) {
        dependsOn pathingJar
        doFirst {
            classpath = files("$buildDir/classes/main", "$buildDir/resources/main", "$projectDir/gsp-classes", pathingJar.archivePath)
        }
    }
}
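The dependencies block lists the artifacts the project depends on.
I basically copied every dependency from the pom.xml shown in the DL4J Quick Start (the full file is on GitHub), then searched the Maven repository for the ones that were still missing and added them.
Also, a plain gradle run on Windows fails with CreateProcess error=206 because there are too many dependency jars, so the pathingJar task works around it. (Reference: https://github.com/grails/grails-core/issues/9125)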
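As far as I understand it, error 206 is the Windows "filename or extension is too long" error, which here comes from the very long classpath on the java command line. The pathing jar avoids this by moving the classpath into the Class-Path attribute of a jar manifest, so the command line only has to name that one jar. The same manifest can be sketched with java.util.jar; PathingManifest and the jar paths below are made up for illustration.

```java
import java.util.Arrays;
import java.util.List;
import java.util.jar.Attributes;
import java.util.jar.Manifest;

// Illustration only: PathingManifest is a hypothetical helper; the paths are invented.
final class PathingManifest {
    /** Builds a manifest whose Class-Path lists the given entries separated by spaces. */
    static Manifest build(List<String> entries) {
        Manifest manifest = new Manifest();
        Attributes main = manifest.getMainAttributes();
        // A manifest without Manifest-Version is not written out correctly.
        main.put(Attributes.Name.MANIFEST_VERSION, "1.0");
        main.put(Attributes.Name.CLASS_PATH, String.join(" ", entries));
        return manifest;
    }

    public static void main(String[] args) {
        List<String> jars = Arrays.asList("/C:/libs/first.jar", "/C:/libs/second.jar");
        Manifest manifest = build(jars);
        System.out.println("Class-Path: "
                + manifest.getMainAttributes().getValue("Class-Path"));
    }
}
```

The Gradle task does the equivalent with the real runtime dependencies and then puts only the build output directories and the pathing jar on the run classpath.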
Running it
>set JAVA_HOME="c:\Program Files\Java\jdk1.8.0_45"
>gradle run -Pmain=jp.hashiwa.dl4j.convolution.CNNMnistCreator
:compileJava UP-TO-DATE
:processResources UP-TO-DATE
:classes UP-TO-DATE
:pathingJar UP-TO-DATE
:run
java version "1.8.0_45"
Java(TM) SE Runtime Environment (build 1.8.0_45-b14)
Java HotSpot(TM) Server VM (build 25.45-b02, mixed mode)
21:51:08.148 [main] INFO j.h.dl4j.convolution.CNNMnistCreator - Load data....
21:51:08.753 [main] INFO j.h.dl4j.convolution.CNNMnistCreator - Build model....
4 29, 2016 9:51:08 午後 com.github.fommil.netlib.BLAS <clinit>
警告: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
4 29, 2016 9:51:09 午後 com.github.fommil.jni.JniLoader liberalLoad
情報: successfully loaded C:\Users\***\AppData\Local\Temp\jniloader6595959076149298041netlib-native_ref-win-i686.dll
****************************************************************
WARNING: COULD NOT LOAD NATIVE SYSTEM BLAS
ND4J performance WILL be reduced
Please install native BLAS library such as OpenBLAS or IntelMKL
See http://nd4j.org/getstarted.html#open for further details
****************************************************************
21:51:09.486 [main] DEBUG org.reflections.Reflections - going to scan these urls:
jar:file:/C:/Users/***/.gradle/caches/modules-2/files-2.1/org.nd4j/nd4j-api/0.4-rc3.8/8247bd513d454843df4daef4b730b901e5d0e7df/nd4j-api-0.4-rc3.8.jar!/
jar:file:/C:/Users/***/.gradle/caches/modules-2/files-2.1/org.nd4j/nd4j-jackson/0.4-rc3.8/620862141252f9f1f9e66873a583e005cf70f19f/nd4j-jackson-0.4-rc3.8.jar!/
jar:file:/C:/Users/***/.gradle/caches/modules-2/files-2.1/org.nd4j/nd4j-common/0.4-rc3.8/5b7fbd7ef0d20816706cadbdf88d8ed8dd609de3/nd4j-common-0.4-rc3.8.jar!/
jar:file:/C:/Users/***/.gradle/caches/modules-2/files-2.1/org.nd4j/nd4j-x86/0.4-rc3.8/ac82fac335e49d29588ce7b0bb241a3e8eb9ab85/nd4j-x86-0.4-rc3.8.jar!/
21:51:09.699 [main] DEBUG org.reflections.Reflections - could not scan file org/nd4j/linalg/cpu/javacpp/linux-x86_64/libjniLoop.so in url jar:file:/C:/Users/***/.gradle/caches/modules-2/files-2.1/org.nd4j/nd4j-x86/0.4-rc3.8/ac82fac335e49d29588ce7b0bb241a3e8eb9ab85/nd4j-x86-0.4-rc3.8.jar!/ with scanner SubTypesScanner
21:51:09.711 [main] DEBUG org.reflections.Reflections - could not scan file org/nd4j/linalg/cpu/javacpp/linux-x86_64/libjniLoop.so in url jar:file:/C:/Users/***/.gradle/caches/modules-2/files-2.1/org.nd4j/nd4j-x86/0.4-rc3.8/ac82fac335e49d29588ce7b0bb241a3e8eb9ab85/nd4j-x86-0.4-rc3.8.jar!/ with scanner TypeAnnotationsScanner
21:51:09.722 [main] INFO org.reflections.Reflections - Reflections took 219 ms to scan 4 urls, producing 116 keys and 359 values
21:51:10.353 [main] INFO j.h.dl4j.convolution.CNNMnistCreator - Train model....
21:51:10.545 [main] WARN o.d.optimize.solvers.BaseOptimizer - Objective function automatically set to minimize. Set stepFunction in neural net configuration to change default settings.
21:51:42.542 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 0 is 2.6776825639390944
21:52:43.315 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 2 is 2.1931480313950775
21:53:43.406 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 4 is 1.6746972585356235
21:54:43.500 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 6 is 1.3039268579930066
21:55:43.650 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 8 is 0.9798040338397026
21:56:43.773 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 0 is 1.0415607730567455
21:57:43.674 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 2 is 0.8497928621399403
21:58:43.376 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 4 is 0.7183761588913202
21:59:43.084 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 6 is 0.7029527971714735
22:00:42.679 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 8 is 0.5542249555617571
22:01:42.507 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 0 is 0.7950156704777478
22:02:42.284 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 2 is 0.6353060844421387
22:03:41.897 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 4 is 0.6403429657995702
22:04:41.944 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 6 is 0.5368056929969788
22:05:41.953 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 8 is 0.4195577607798576
22:06:41.745 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 0 is 0.3681572108781338
22:07:41.494 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 2 is 0.32245658259510995
22:08:41.221 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 4 is 0.2891946652776003
22:09:40.900 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 6 is 0.252537038629055
22:10:40.784 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 8 is 0.21910545122802258
22:11:10.775 [main] INFO j.h.dl4j.convolution.CNNMnistCreator - Evaluate weights....
22:11:10.777 [main] INFO j.h.dl4j.convolution.CNNMnistCreator - Evaluate model....
22:11:19.083 [main] INFO j.h.dl4j.convolution.CNNMnistCreator -
Examples labeled as 0 classified by model as 0: 37 times
Examples labeled as 0 classified by model as 7: 2 times
Examples labeled as 0 classified by model as 8: 1 times
Examples labeled as 1 classified by model as 1: 42 times
Examples labeled as 1 classified by model as 6: 1 times
Examples labeled as 2 classified by model as 2: 31 times
Examples labeled as 2 classified by model as 6: 1 times
Examples labeled as 2 classified by model as 8: 1 times
Examples labeled as 3 classified by model as 3: 35 times
Examples labeled as 3 classified by model as 6: 2 times
Examples labeled as 4 classified by model as 4: 38 times
Examples labeled as 4 classified by model as 6: 1 times
Examples labeled as 4 classified by model as 9: 1 times
Examples labeled as 5 classified by model as 4: 3 times
Examples labeled as 5 classified by model as 5: 40 times
Examples labeled as 5 classified by model as 6: 1 times
Examples labeled as 6 classified by model as 0: 1 times
Examples labeled as 6 classified by model as 6: 40 times
Examples labeled as 6 classified by model as 8: 1 times
Examples labeled as 7 classified by model as 2: 2 times
Examples labeled as 7 classified by model as 7: 44 times
Examples labeled as 7 classified by model as 9: 2 times
Examples labeled as 8 classified by model as 3: 1 times
Examples labeled as 8 classified by model as 4: 1 times
Examples labeled as 8 classified by model as 6: 1 times
Examples labeled as 8 classified by model as 8: 28 times
Examples labeled as 8 classified by model as 9: 1 times
Examples labeled as 9 classified by model as 0: 1 times
Examples labeled as 9 classified by model as 4: 1 times
Examples labeled as 9 classified by model as 7: 2 times
Examples labeled as 9 classified by model as 9: 37 times
==========================Scores========================================
Accuracy: 0.93
Precision: 0.9317
Recall: 0.9293
F1 Score: 0.9305
========================================================================
22:11:19.210 [main] INFO j.h.dl4j.convolution.CNNMnistCreator - Save model....
22:11:27.004 [main] INFO j.h.dl4j.convolution.CNNMnistCreator - ****************Example finished********************
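The warning appears because I skipped installing BLAS.
Even with 2,000 samples and the slowdown from having no BLAS, training finished in about 20 minutes. About 93% of the evaluation examples were classified correctly (Accuracy: 0.93). The Score at iteration ... lines also show the score falling overall as training proceeds. I used ScoreIterationListener here, but org.deeplearning4j.ui.weights.HistogramIterationListener plots the score and other values so you can watch them in a browser, which I think is more fun.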
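The reported accuracy can be checked by hand from the "Examples labeled as ..." lines in the log. The sketch below copies the per-label counts from that output; ConfusionSummary is a hypothetical class written for this note. Averaging the per-label recall also reproduces the reported Recall value, which suggests (though I have not confirmed it) that the library averages recall over the labels.

```java
// Illustration only: ConfusionSummary is a hypothetical helper, not part of DL4J.
final class ConfusionSummary {
    // Counts copied from the evaluation log, indexed by label 0..9.
    static final int[] CORRECT = {37, 42, 31, 35, 38, 40, 40, 44, 28, 37};
    static final int[] TOTAL   = {40, 43, 33, 37, 40, 44, 42, 48, 32, 41};

    static int sum(int[] values) {
        int total = 0;
        for (int v : values) {
            total += v;
        }
        return total;
    }

    /** Fraction of all evaluation examples that were classified correctly. */
    static double accuracy() {
        return (double) sum(CORRECT) / sum(TOTAL);
    }

    /** Recall computed per label and then averaged over the labels. */
    static double macroRecall() {
        double total = 0;
        for (int i = 0; i < CORRECT.length; i++) {
            total += (double) CORRECT[i] / TOTAL[i];
        }
        return total / CORRECT.length;
    }

    public static void main(String[] args) {
        System.out.printf("correct %d of %d%n", sum(CORRECT), sum(TOTAL));
        System.out.printf("accuracy     %.4f%n", accuracy());
        System.out.printf("macro recall %.4f%n", macroRecall());
    }
}
```

The counts add up to 372 correct out of 400, which is the 0.93 in the log.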
The network structure and the trained parameters are saved in the following files.
>dir logs
2016/04/29 22:11 1,724,370 convolution.bin
2016/04/29 22:11 8,326 convolution.json
Loading these files (as CNNMnistReader.java does) lets you use the trained network later.
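The size of convolution.bin can be roughly explained by counting the parameters of the network. This sketch assumes one bias per output unit or feature map, no padding (so the fully connected layer has 4 x 4 x 50 = 800 inputs), and 4-byte floats; ParamCount is a hypothetical class written for this note.

```java
// Illustration only: ParamCount is a hypothetical helper, not part of DL4J.
final class ParamCount {
    /** Weights plus biases of a square-kernel convolution layer. */
    static long conv(int kernel, int inChannels, int outChannels) {
        return (long) kernel * kernel * inChannels * outChannels + outChannels;
    }

    /** Weights plus biases of a fully connected layer. */
    static long dense(int in, int out) {
        return (long) in * out + out;
    }

    /** Total parameter count of the network under the stated assumptions. */
    static long total() {
        return conv(5, 1, 20)            // layer 0: 1 -> 20 feature maps
                + conv(5, 20, 50)        // layer 2: 20 -> 50 feature maps
                + dense(4 * 4 * 50, 500) // layer 4: flattened feature maps -> 500
                + dense(500, 10);        // layer 5: 500 -> 10 classes
    }

    public static void main(String[] args) {
        long params = total();
        System.out.println("parameters:         " + params);
        System.out.println("bytes as 4-byte floats: " + params * 4);
        System.out.println("convolution.bin:    " + 1_724_370L);
    }
}
```

This gives 431,080 parameters, or 1,724,320 bytes, against a file size of 1,724,370 bytes. The remaining 50 bytes are presumably header information written along with the values, but I have not verified that.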
That's all for today.