LoginSignup
16
9

More than 5 years have passed since last update.

KerasのMobileNetモデルをTensorFlow Liteモデルファイルに変換できるように書き換える

Last updated at Posted at 2019-01-03

動作環境

TensorFlow 1.12.0を使用します。それ以前のバージョンではtflite_convertに成功しませんでした。

KerasのMobileNetはTensorFlow Liteモデルファイルに変換できない

KerasのMobileNetを使い、学習済みの重みを読み込み、SavedModelを作成します。

keras2saved_model.py
import tensorflow as tf

model = tf.keras.applications.MobileNet(
    input_shape=(224,224,3),
    alpha=0.5,weights=None, classes=101)
# 学習済みの重みを読み込む
model.load_weights("weight.hdf5")
# SavedModelを作成する
sess = tf.keras.backend.get_session()
tf.saved_model.simple_save(sess,"saved_model/",
    inputs={'input': model.inputs[0]},
    outputs={'output': model.outputs[0]})

tflite_convertコマンドでTensorFlow Liteモデルファイルに変換します。

make_tflite_model.sh
tflite_convert  --saved_model_dir saved_model/ --output_file graph.lite

すると大量のエラーメッセージが出てしまい変換できません。

b'2019-01-02 17:51:13.673615: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1080] Converting unsupported operation: ReadVariableOp\n2019-01-02 17:51:13.673671: 

中略

Check failed: other_op->type == OperatorType::kMerge Found (Unsupported TensorFlow op: ReadVariableOp) as non-selected output from Switch, but only Merge supported.\nAborted (core dumped)\n'
None

モバイル向けのネットワークなのにモバイル向けに変換できません。

理由としてはTensorFlow LiteはTensorFlowに実装されているオペレータをすべて使えないためです。使えるオペレータは公式のTensorFlow Lite & TensorFlow Compatibility Guideというページで解説されています。

Kerasを書き換えてTensorFlow Liteモデルファイルに変換できるようにする

まずインストールされているKerasのソースコードと今回改造するKerasのソースコードを分離するために、今回改造するKerasのソースコードを今回のプロジェクトディレクトリにコピーします。改造元のTensorFlowは1.10.0を使用します。pipコマンドでインストールしたTensorFlowとは別にこちらからダウンロードしてください。それ以外のバージョンではインポートがうまくいきませんでした。

# TensorFlowは/tmp/ディレクトリで解凍したとする
cp /tmp/tensorflow-1.10.0/tensorflow/python/keras/applications/mobilenet.py ./
cp /tmp/tensorflow-1.10.0/tensorflow/python/keras/layers/normalization.py ./
cp /tmp/tensorflow-1.10.0/tensorflow/python/keras/applications/imagenet_utils.py ./

MobileNetではDropoutレイヤーを使っていますが、 Compatibility GuideではDropoutには対応していません。よって、コメントアウトします。

mobilenet.py
    x = GlobalAveragePooling2D()(x)
    x = Reshape(shape, name='reshape_1')(x)
    # TensorFlow Liteでは非対応なのでコメントアウト
    # x = Dropout(dropout, name='dropout')(x)
    x = Conv2D(classes, (1, 1), padding='same', name='conv_preds')(x)
    x = Activation('softmax', name='act_softmax')(x)
    x = Reshape((classes,), name='reshape_2')(x)

エラーメッセージにはMergeがUnsupportedという内容がありました。MergeオペレータはBatchNormalizationレイヤー内のtf_utils.smart_cond関数で使われていました。この関数は第1引数がTrueならば第2引数の関数、Falseならば第3引数の関数を使うモデルを作成します。第1引数は学習ならばTrue予測ならばFalseで、TensorFlow Liteでは予測を行うので、Falseの場合のみの実装があればよいです。tf_utils.smart_cond関数を使わないように書き換えます。

normalization.py
  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)
    #この関数は非対応のMergeオペレータを使用しているので使わない
    #output, mean, variance = tf_utils.smart_cond(
    #    training, _fused_batch_norm_training, _fused_batch_norm_inference)
        # 予測時に使用する方の関数を使う
    output, mean, variance = _fused_batch_norm_inference()

どのようにしてDropoutとBatchNormalizationに原因があることを突き止めたかという解説は次の説で行います。

MobileNet呼び出し元とMobileNet内部でインポートするソースコードを変更します。

model.py
# 書き換えたカレントディレクトリのMobileNetをインポート
from mobilenet import MobileNet
model = MobileNet(
    input_shape=(224,224,3),
    alpha=0.5,weights=None, classes=101)
mobilenet.py
# インポートできないのでカレントディレクトリのimagenet_utilsモジュールを利用
# from tensorflow.python.keras.applications import imagenet_utils
# from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
# from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
import imagenet_utils
from imagenet_utils import _obtain_input_shape
from imagenet_utils import decode_predictions
# 改造したカレントディレクトリのBatchNormalizationを利用
#from tensorflow.python.keras.layers import BatchNormalization
from normalization import BatchNormalization

再びSavedModelの作成とtflite_converterによる変換を行うとエラーが発生せずにTensorFlow Lite モデルファイルを作成することができました。

使用しているオペレータを調査する方法

Kerasは使いやすい上位API群ですが、TensorFlowのオペレータを隠蔽していまいます。どのようなオペレータが使用されているかはTensorBoardで調査することができます。

まずTensorBoard用にログを出力します。

keras2log.py
import tensorflow as tf

model = tf.keras.applications.MobileNet(
    input_shape=(224,224,3),
    alpha=0.5,weights=None, classes=101)
sess = tf.keras.backend.get_session()
writer = tf.summary.FileWriter("log", sess.graph)
writer.close()

TensorBoardを起動します。

tensorboard  --logdir log

ブラウザでアクセスします。 http://localhost:6006/

tensorboard.png

ノードを開いていくことで、BatchNormalization付近でMergeオペレータを使用していることが分かりました。

全体ソースコード

TensorFlow Lite モデルファイルの作成はこちらになります。
それを使用したデモアプリがこちらになります。

16
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
9