24
21

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Kerasで作ったモデルをUnityに持っていくときのハマりどころ

Posted at

はじめに

Unityでは、ゲーム内で強化学習させるならml-agentsとかKelpNetなどを使えますが、一方でゲーム中に得たデータを保存し、別の環境で機械学習させた後に学習結果をUnityにもっていく、という方法もあります。

そういう方法をDeep Neural Networkで行う場合はKerasが便利です。ネットワークを設計するのも簡単ですし、Google Colaboratoryにはデフォルトで入ってるので環境構築に悩むことなくすぐに作業できます。

ただ、Kerasで作ったモデルをUnity、特にOculus GoなどのAndroidデバイスに持っていくときにハマりどころがいくつかあります。しかも、ネットで見られる情報が不完全で試しても上手くいかないことが多いです。この記事では、そのハマりどころの紹介と回避方法を紹介します。ただし、使っている関数のいくつかがdeprecatedなので、そのうち修整が必要になるでしょう。問題が発生しましたら記事へのコメントでお知らせください。なお、記事で使用しているUnityのバージョンは2018.3.8f1です。

Kerasモデルの準備

こちらのチュートリアル の内容を実装します。開発環境はGoogle Colaboratoryです。

まず、こちらのデータをダウンロードして、pima-indians-diabetes.data.csvというファイル名でGoogle Colaboratoryのドライブにアップロードしてください。

次に、Keras Functional APIでモデルと学習プロセスを定義します。ここでinputは"input_x",outputは"output_y"と定義していることに注意してください。

from keras.models import Sequential
from keras.layers import Dense
import numpy

numpy.random.seed(7)

dataset = numpy.loadtxt("pima-indians-diabetes.data.csv", delimiter=",")

X = dataset[:,0:8]
Y = dataset[:,8]

from keras.layers import Input
from keras.models import Model

inputs = Input(shape=(8,),name='input_x')
x = Dense(12, activation='relu')(inputs)
x = Dense(8, activation='relu')(x)
predictions = Dense(1, activation='sigmoid',name='output_y')(x)
model = Model(input=inputs, output=predictions)

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X, Y, epochs=150, batch_size=10)
scores = model.evaluate(X, Y)
print("\n%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

学習が終了したら、

print(model.input)
print(model.output)

でinputとoutputの情報を見てください。

Tensor("input_x:0", shape=(?, 8), dtype=float32)
Tensor("output_y/Sigmoid:0", shape=(?, 1), dtype=float32)

上記のように、inputのほうは"input_x"と元の名前の通りですが、outputの方は"/Sigmoid"が追加されてます。この追加されてる方の名前をUnityで使うので注意してください。

#KerasのモデルをTensorFlowグラフとして出力する
Kerasのモデルは直接Unityで読めないので、一旦TensorFlowの形式に変換します。ネットでググると変換の方法がいくつも出てきますが、ほとんどの方法はUnityに持っていったときに上手くいかなかったです。ここでは、こちらの記事の方法を紹介します。以下のコードはほぼ元記事の通りですが、そのままだとエラーが出るので一部修正しています。

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
from keras import backend as K

ksess = K.get_session()

K.set_learning_phase(0)
graph = ksess.graph

num_output = 1
prefix = "output"
pred = [None]*num_output
outputName = [None]*num_output
for i in range(num_output):
    outputName[i] = prefix + str(i)
    pred[i] = tf.identity(model.get_output_at(i), name=outputName[i])

constant_graph = graph_util.convert_variables_to_constants(ksess, ksess.graph.as_graph_def(), outputName)

コードの内容を見るとわかりますが、Kerasのセッションをget_session()でTensorFlowのセッションとして呼び出し、学習したモデル内の変数をgraph_util.convert_variables_to_constantsで定数に変換しています。

output_dir = "./"
output_graph_name = "keras2tf.pb"
graph_io.write_graph(constant_graph, output_dir, output_graph_name, as_text=False)

これで、keras2tf.pbというファイルにTensorflowグラフが保存されました。このファイルをローカルPCにダウンロードして、拡張子を.bytesに変更してください。

#Unity内での作業
ここからはUnity内での作業になります。以下の作業はOculus Go向けのものですが、他のAndroidデバイスでも同様の流れで大丈夫かと思います。

Unityで新しいプロジェクトを作ったら、Androidをビルドターゲットにし、Player SettingでAPI Lebelを25に、Scripting Define SymbolをENABLE_TENSORFLOWにします。

つぎに、TensorFlowSharp Unityプラグインをインポートしてください。こちらからダウンロードできます。

image.png

ImportするとPlugins/Androidというフォルダが作られますが、その中の"System."と最初に名前のつくファイルは全て削除してください。そうしないとビルド時にエラーが出ます。どうやら、Unity2018.3で出るエラーのようです

あと、上で作ったkeras2tf.bytesをAssetsフォルダの下にドラッグアンドドロップしてください。

次にModelImportExample.csというファイルを作り、以下のスクリプトを入力してください。スクリプト作成にはこちらの記事を参考にしました。

ModelImportExample.cs

using UnityEngine;
using TensorFlow;

public class ModelImportExample : MonoBehaviour
{
    public TextAsset model;
    private float[,] inputTensor = new float[1, 8];
    private float[] testData = new float[] { 6f, 148f, 72f, 35f, 0f, 33.6f, 0.627f, 50f };

    void Start()
    {
        #if UNITY_ANDROID && !UNITY_EDITOR
                TensorFlowSharp.Android.NativeBinding.Init();
        #endif

        TFGraph graph = new TFGraph();
        graph.Import(model.bytes);
        TFSession sess = new TFSession(graph);

        for (int i = 0; i < 8; i++)
        {
            inputTensor[0, i] = testData[i];
        }

        TFTensor input = inputTensor;

        var runner = sess.GetRunner();
        var test = runner.AddInput(graph["input_x"][0], input);
        test.Fetch(graph["output_y/Sigmoid"][0]);
        var output = runner.Run();

        var result = output[0].GetValue() as float[,];

        Debug.Log(result[0,0]);
        
    }
}

重要なポイントは、まず、Android向けにビルドするときは

#if UNITY_ANDROID && !UNITY_EDITOR
   TensorFlowSharp.Android.NativeBinding.Init();
#endif

を追加してください。無いとビルドできません。

つぎに、インポートしたkeras2tf.bytesはTextAssetとしてエディタ上でmodelに割り当ててください。

image.png

あと、グラフへの入力と出力は

   var test = runner.AddInput(graph["input_x"][0], input);
   test.Fetch(graph["output_y/Sigmoid"][0]);

で定義しています。それぞれの名前は、Keras側で確認した通り"input_x", "output_y/Sigmoid"になっていることに注意してください。

これでエディタ上でスクリプトを実行すると、コンソールに"0.9049003"とアウトプットが出てくるはずです。確認したら、ビルドしてください。正常に終了するはずです。Oculus Goでも動作確認済みです。

24
21
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
24
21

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?