はじめに
本記事では、tensorflow2.0の学習済みモデルをKotlinで利用する方法を紹介します。
サンプルコードはKotlinのみですが、Javaでも同様の方法で動作するかと思います。
今回紹介する方法では、deeplearning4jというライブラリのKerasModelImportを利用します。
tensorflowにはJava APIが存在するので、deeplearning4j などというマイナーなライブラリ を使う必要は本来無いのですが、Java向けのtensorflow2.0ビルドは未だ配布されていないため、暫定的にdeeplearning4jを利用します。
(※自分でビルドすればtensorflow2.0対応のJava APIを利用できるかもしれません)
つまり、
Kotlin/JavaでDeepLearningの推論処理を動かしたい!
でも、deeplearning4jで学習コード書きたくない!!
Java向けのtensorflow2.0ビルドの配布を待てない!!!
こんな人のための繋ぎの対応策と捉えていただければと思います。
ライブラリのバージョン
tensorflowもdeeplearning4jもバージョンごとに変化が激しいので、バージョンが異なると挙動が変わる可能性が高いです。
tensorflow(Python) : 2.1.0
deeplearning4j(Kotlin/Java) : 1.0.0-beta7
学習コード(Python, tensorflow)
サンプルとして、tensorflow2.0でMNISTの学習を行うPythonコードを以下に載せます。
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model
# float64でないとdeeplearning4j上で正常動作しない場合があります
tf.keras.backend.set_floatx('float64')
# Sequential、SubClassedモデルはdeeplearning4jでインポートに失敗します
def make_functional_model(data_size, target_size):
inputs = Input(data_size)
fc1 = Dense(512, activation='relu')
fc2 = Dense(target_size, activation='softmax')
outputs = fc2(fc1(inputs))
return Model(inputs=inputs, outputs=outputs)
def train(dataset, model, criterion, optimizer):
for data, target in dataset:
with tf.GradientTape() as tape:
output = model(data)
loss = criterion(target, output)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
def main():
# MNISTの準備
mnist = tf.keras.datasets.mnist
(train_data, train_target), _ = mnist.load_data()
train_data = train_data / 255.0
train_data = np.reshape(train_data, (train_data.shape[0], -1))
train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_target)).batch(32)
# モデルの準備
data_size = train_data.shape[1]
target_size = 10
model = make_functional_model(data_size, target_size)
# 学習
criterion = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
for epoch in range(5):
train(train_dataset, model, criterion, optimizer)
# Saved Modelはdeeplearning4jでインポートに失敗します
model.save('checkpoint.h5')
if __name__ == '__main__':
main()
学習コードは以上です。
なお、deeplearning4jで正常動作が確認できたのは、「Functionalモデルをhdf5形式で保存」した場合のみです。
Sequential、SubClassed形式でのモデル記述や、Saved Modelの保存形式ではインポートできませんでした。
推論コード(Kotlin, deeplearning4j)
続いて、推論コードです。
こちらはKotlinでdeeplearning4jを利用しています。
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport
fun main() {
val mnist = MnistDataSetIterator(1, false, 123)
val model = KerasModelImport.importKerasModelAndWeights("checkpoint.h5")
//val model = KerasModelImport.importKerasSequentialModelAndWeights("checkpoint.h5")
mnist.forEach {
val data = it.features
val output = model.output(data)[0].toFloatVector()
val pred = output.indexOf(output.max()!!)
}
}
モデルをインポートして、output関数で推論するだけです。
とても簡単に利用できますね。
なお、Sequentialモデル用のインポート関数も存在しますが、前述したようにインポートに失敗します。
おわりに
本記事ではtensorflow2.0の学習済みモデルをKotlin/Javaで利用する方法を紹介しました。
tensorflowのJava APIが安定して利用できる日が来ることを祈りながら、今は耐えるのです。。
そして、みなさんもKotlinでDeepLearningしましょう!(これが一番言いたいこと)