LoginSignup
4
2

More than 3 years have passed since last update.

tensorflow2.0の学習済みモデルをKotlin/Javaで利用する方法

Last updated at Posted at 2020-07-23

はじめに

本記事では、tensorflow2.0の学習済みモデルをKotlinで利用する方法を紹介します。
サンプルコードはKotlinのみですが、Javaでも同様の方法で動作するかと思います。

今回紹介する方法では、deeplearning4jというライブラリのKerasModelImportを利用します。
tensorflowにはJava APIが存在するので、deeplearning4j などというマイナーなライブラリ を使う必要は本来無いのですが、Java向けのtensorflow2.0ビルドは未だ配布されていないため、暫定的にdeeplearning4jを利用します。
(※自分でビルドすればtensorflow2.0対応のJava APIを利用できるかもしれません)

つまり、

Kotlin/JavaでDeepLearningの推論処理を動かしたい!
でも、deeplearning4jで学習コード書きたくない!!
Java向けのtensorflow2.0ビルドの配布を待てない!!!

こんな人のための繋ぎの対応策と捉えていただければと思います。

ライブラリのバージョン

tensorflowもdeeplearning4jもバージョンごとに変化が激しいので、バージョンが異なると挙動が変わる可能性が高いです。

tensorflow(Python) : 2.1.0
deeplearning4j(Kotlin/Java) : 1.0.0-beta7

学習コード(Python, tensorflow)

サンプルとして、tensorflow2.0でMNISTの学習を行うPythonコードを以下に載せます。

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model

# float64でないとdeeplearning4j上で正常動作しない場合があります
tf.keras.backend.set_floatx('float64')

# Sequential、SubClassedモデルはdeeplearning4jでインポートに失敗します
def make_functional_model(data_size, target_size):
    inputs = Input(data_size)
    fc1 = Dense(512, activation='relu')
    fc2 = Dense(target_size, activation='softmax')
    outputs = fc2(fc1(inputs))
    return Model(inputs=inputs, outputs=outputs)

def train(dataset, model, criterion, optimizer):
    for data, target in dataset:
        with tf.GradientTape() as tape:
            output = model(data)
            loss = criterion(target, output)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

def main():

    # MNISTの準備
    mnist = tf.keras.datasets.mnist
    (train_data, train_target), _ = mnist.load_data()
    train_data = train_data / 255.0
    train_data = np.reshape(train_data, (train_data.shape[0], -1))
    train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_target)).batch(32)

    # モデルの準備
    data_size = train_data.shape[1]
    target_size = 10
    model = make_functional_model(data_size, target_size)

    # 学習
    criterion = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam()
    for epoch in range(5):
        train(train_dataset, model, criterion, optimizer)

    # Saved Modelはdeeplearning4jでインポートに失敗します
    model.save('checkpoint.h5')

if __name__ == '__main__':
    main()

学習コードは以上です。
なお、deeplearning4jで正常動作が確認できたのは、「Functionalモデルをhdf5形式で保存」した場合のみです。
Sequential、SubClassed形式でのモデル記述や、Saved Modelの保存形式ではインポートできませんでした。

推論コード(Kotlin, deeplearning4j)

続いて、推論コードです。
こちらはKotlinでdeeplearning4jを利用しています。

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport

fun main() {

    val mnist = MnistDataSetIterator(1, false, 123)
    val model = KerasModelImport.importKerasModelAndWeights("checkpoint.h5")
    //val model = KerasModelImport.importKerasSequentialModelAndWeights("checkpoint.h5")

    mnist.forEach {
        val data = it.features
        val output = model.output(data)[0].toFloatVector()
        val pred = output.indexOf(output.max()!!)
    }
}

モデルをインポートして、output関数で推論するだけです。
とても簡単に利用できますね。
なお、Sequentialモデル用のインポート関数も存在しますが、前述したようにインポートに失敗します。

おわりに

本記事ではtensorflow2.0の学習済みモデルをKotlin/Javaで利用する方法を紹介しました。
tensorflowのJava APIが安定して利用できる日が来ることを祈りながら、今は耐えるのです。。

そして、みなさんもKotlinでDeepLearningしましょう!(これが一番言いたいこと)

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2