Help us understand the problem. What is going on with this article?

ゼロから作るDeep Learning Java編 第1章 はじめに

More than 1 year has passed since last update.

目次

はじめに

「ゼロから作るDeep Learning ――Pythonで学ぶディープラーニングの理論と実装」(オライリー・ジャパン 斎藤 康毅 著)はDeep Learningを初歩から解説した本です。プログラミング言語としてPythonを使用していますが、本投稿は(およびこれに続く一連の投稿では)これをJavaで実装します。ただしゼロから作るDeep Learningを置き換えるものではなく、併読することを前提として記述しています。

使用する外部ライブラリ

Numpy

ゼロから作るDeep Learningでは数値計算のためのライブラリとしてNumPyを使用しています。本投稿ではND4Jを使用します。ND4JはNumPyやMatlibに似た機能を持つJava用のライブラリです。オープンソースのDeeplearning4Jで使用されていることで有名です。ただしドキュメントはそれほど整備されていません。多少、試行錯誤して使い方を習得する必要があります。
ND4JはCUDAを経由してGPUを使用することができます。高速化するためには有利であると考えます。

Matplotlib

ゼロから作るDeep Learningではグラフ描画のためのライブラリとしてMatplotlibを使用しています。Deep Learningを実現する上でグラフは必須ではないため、本投稿ではグラフ描画のためのライブラリは使用しません。

進め方

ゼロから作るDeep Learningに登場するPythonで記述されたプログラムを逐次Javaで書き直していきます。結果が同じになることを確認するため、JUnitのコードとして記述します。例えば以下のような感じです。

public class C1_5_NumPy {

    @Test
    public void C1_5_2_NumPy配列の生成() {
        INDArray x = Nd4j.create(new double[] {1.0, 2.0, 3.0});
        assertEquals("[1.00,2.00,3.00]", Util.string(x));
    }
}

これは「1 Python入門 1.5 NumPy 1.5.2 NumPy配列の生成」に登場するサンプルプログラムをJavaで記述したものです。クラス名やメソッド名に日本語を使用しているので、環境によっては動作しないかもしれません。私の環境(Windows10 + Eclipse Oxygen.2 Release 4.7.2)では問題なく動きます。
Pythonのコードが記述されている節についてのみ記述しているので見出しの項番は連番になっていません。Javaで記述したすべてのプログラムはGitHubのsaka1029/Deep.Learningに掲載しています。

環境設定

GitHub上のプロジェクトをビルドするためには以下の依存関係をpom.xmlに設定する必要があります。

  <dependencies>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>0.9.1</version>
    </dependency>
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-log4j12</artifactId>
        <version>1.7.2</version>
    </dependency>
  </dependencies>

またIDE上のプロジェクトの設定でJava8を使用するように設定しておく必要があります。Java9とすると実行時にクラスローダが例外をスローすることがあります。

1.5 ND4J

ゼロから作るDeep Learningでは第1章でPython、Numpy、Matplotlibについて説明していますが、本投稿ではND4Jについてのみ説明します。

1.5.2 ND4Jの配列の生成

ND4Jでは配列の型はINDArrayになります。初期値を与えて初期化する場合はNd4jクラスのファクトリーメソッドを使って以下のようにします。Util.string(INDArray)は配列の内容を文字列化する自作のユーティリティ関数です。

INDArray x = Nd4j.create(new double[] {1.0, 2.0, 3.0});
assertEquals("[1.00,2.00,3.00]", Util.string(x));

1.5.3 ND4Jの算術計算

1次元配列の四則演算は以下のようにします。

INDArray x = Nd4j.create(new double[] {1.0, 2.0, 3.0});
INDArray y = Nd4j.create(new double[] {2.0, 4.0, 6.0});
assertEquals("[3.00,6.00,9.00]", Util.string(x.add(y)));
assertEquals("[-1.00,-2.00,-3.00]", Util.string(x.sub(y)));
assertEquals("[2.00,8.00,18.00]", Util.string(x.mul(y)));
assertEquals("[0.50,0.50,0.50]", Util.string(x.div(y)));

INDArray.mul(INDArray)は要素ごとの掛け算で、行列積はINDArray.mmul(INDArray)です。
INDArrayにはこれ以外にもadd(Number)、sub(Number)、mul(Number)、div(Number)がオーバーロードされているので、スカラー値との四則演算も同様に行うことができます。
xの実体は1行3列の2次元配列となります。この点はNumPyと異なるので注意が必要です。INDArray.rank()は配列の次元を返します。

assertArrayEquals(new int[] {1,3}, x.shape());
assertEquals(2, x.rank());

1.5.4 ND4JのN次元配列

2次元以上になってもINDArrayを使う点はNumPyと同じ考え方です。初期値を与えて2次元配列を作成する場合は以下のようにします。

INDArray A = Nd4j.create(new double[][] {{1, 2}, {3, 4}});
assertEquals("[[1.00,2.00],[3.00,4.00]]", Util.string(A));
assertArrayEquals(new int[] {2,2}, A.shape());
assertEquals(2, A.rank());

1.5.5 ブロードキャスト

NumPyでは次元の異なる配列どうしをそのまま四則演算することが可能ですが、ND4Jではそれはできません。INDArray.broadcast(int[])メソッドを使って明示的に次元を合わせてやる必要があります。
これを忘れるとIllegalStateExceptionがスローされます。

INDArray A = Nd4j.create(new double[][] {{1, 2}, {3, 4}});
INDArray B = Nd4j.create(new double[] {10, 20});
// INDArrayのメソッドmul(INDArray)やadd(INDArray)は自動的にブロードキャストしません。
// broadcast(int[])を使用して左辺の次元に合わせてやる必要があります。
assertEquals("[[10.00,40.00],[30.00,80.00]]", Util.string(A.mul(B.broadcast(A.shape()))));
// 単純に掛け算するとIllegalStateException: Mis matched shapesとなります。
try {
    assertEquals("[[10.00,40.00],[30.00,80.00]]", Util.string(A.mul(B)));
    fail();
} catch (IllegalStateException e) {
    assertEquals("Mis matched shapes", e.getMessage());
}
// あるいはmulRowVector(INDArray)を使うこともできます。
assertEquals("[[10.00,40.00],[30.00,80.00]]", Util.string(A.mulRowVector(B)));

1.5.6 要素へのアクセス

INDArrayはIterableインタフェースを実装していません。またIteratorを返すメソッドなどもありません。要素へのアクセスは添え字を使って行う必要があります。
要素を取り出すときはINDArray.getDouble(int...)、行を取り出すときはINDArray.getRow(int)、列を取り出すときはINDArray.getColumn(int)を使います。2次元配列を1次元に変化する場合はNd4j.toFlattened(INDArray)メソッドを使います。

INDArray X = Nd4j.create(new double[][] {{51, 55}, {14, 19}, {0, 4}});
assertEquals("[[51.00,55.00],[14.00,19.00],[0.00,4.00]]", Util.string(X));
assertEquals("[51.00,55.00]", Util.string(X.getRow(0)));
assertEquals(55.0, X.getDouble(0, 1), 5e-6);
// INDArrayはIterableインタフェースを実装していません。
for (int i = 0, size = X.size(0); i < size; ++i)
    assertEquals(2, X.getRow(i).size(1));
// Xをベクトルに変換します。
X = Nd4j.toFlattened(X);
assertEquals("[51.00,55.00,14.00,19.00,0.00,4.00]", Util.string(X));
// 指定した要素をすべて取り出すこともできます。
assertEquals("[51.00,14.00,0.00]", Util.string(X.getColumns(0, 2, 4)));

データ型

今までの例を見るとdoubleの配列を使って初期化しているので、INDArrayは内部でdoubleの配列を保持しているように見えます。しかしND4Jはデフォルトではfloatの配列を確保します。ドキュメントには以下の記述があります。

データ型の設定
ND4Jは現在、float精度値またはdouble精度値によるINDArrayによるバッキングを許可しています。デフォルトは単精度(float)です。ND4Jがdouble精度を配列全体に使用するように設定するには、以下を使用することができます。
0.4-rc3.8、及びそれ以前の場合、
Nd4j.dtype = DataBuffer.Type.DOUBLE;
NDArrayFactory factory = Nd4j.factory();
factory.setDType(DataBuffer.Type.DOUBLE);
0.4-rc3.9、及びそれ以降の場合、
DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);

DataTypeUtilを使ってdoubleの配列を作成してみます。

// デフォルト精度の配列を作成します。
INDArray a = Nd4j.create(new double[] {1D / 3});
// 倍精度(double)に設定します。
DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
// doubleに変更されていることがわかります。
assertEquals(DataBuffer.Type.DOUBLE, DataTypeUtil.getDtypeFromContext());
// 倍精度の配列を作成します。
INDArray b = Nd4j.create(new double[] {1D / 3});
// aはdoubleで初期化しましたが単精度の配列です。
assertEquals(0.3333333432674408, a.getDouble(0), 5e-14);
// bはdoubleで初期化しましたが倍精度の配列です。
assertEquals(0.3333333333333333, b.getDouble(0), 5e-14);
// 単精度(float)に戻します。
DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);

ND4Jがデフォルトでfloatを使用するのは主としてDeep Learningでの利用を想定しているからだと思います。ND4JはCUDAを経由してGPUを使用することができますが、その場合もfloatの方がハンドリングがしやすいのでしょう。
この一連の投稿では一貫してdouble型を使いますが、INDArrayの内部はfloat型である点に注意してください。

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away