動作環境
TensorFlow 1.12.0を使用します。それ以前のバージョンではtflite_convertに成功しませんでした。
KerasのMobileNetはTensorFlow Liteモデルファイルに変換できない
KerasのMobileNetを使い、学習済みの重みを読み込み、SavedModelを作成します。
import tensorflow as tf
model = tf.keras.applications.MobileNet(
input_shape=(224,224,3),
alpha=0.5,weights=None, classes=101)
# 学習済みの重みを読み込む
model.load_weights("weight.hdf5")
# SavedModelを作成する
sess = tf.keras.backend.get_session()
tf.saved_model.simple_save(sess,"saved_model/",
inputs={'input': model.inputs[0]},
outputs={'output': model.outputs[0]})
tflite_convertコマンドでTensorFlow Liteモデルファイルに変換します。
tflite_convert --saved_model_dir saved_model/ --output_file graph.lite
すると大量のエラーメッセージが出てしまい変換できません。
b'2019-01-02 17:51:13.673615: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1080] Converting unsupported operation: ReadVariableOp\n2019-01-02 17:51:13.673671:
中略
Check failed: other_op->type == OperatorType::kMerge Found (Unsupported TensorFlow op: ReadVariableOp) as non-selected output from Switch, but only Merge supported.\nAborted (core dumped)\n'
None
モバイル向けのネットワークなのにモバイル向けに変換できません。
理由としてはTensorFlow LiteはTensorFlowに実装されているオペレータをすべて使えないためです。使えるオペレータは公式のTensorFlow Lite & TensorFlow Compatibility Guideというページで解説されています。
Kerasを書き換えてTensorFlow Liteモデルファイルに変換できるようにする
まずインストールされているKerasのソースコードと今回改造するKerasのソースコードを分離するために、今回改造するKerasのソースコードを今回のプロジェクトディレクトリにコピーします。改造元のTensorFlowは1.10.0を使用します。pipコマンドでインストールしたTensorFlowとは別にこちらからダウンロードしてください。それ以外のバージョンではインポートがうまくいきませんでした。
# TensorFlowは/tmp/ディレクトリで解凍したとする
cp /tmp/tensorflow-1.10.0/tensorflow/python/keras/applications/mobilenet.py ./
cp /tmp/tensorflow-1.10.0/tensorflow/python/keras/layers/normalization.py ./
cp /tmp/tensorflow-1.10.0/tensorflow/python/keras/applications/imagenet_utils.py ./
MobileNetではDropoutレイヤーを使っていますが、 Compatibility GuideではDropoutには対応していません。よって、コメントアウトします。
x = GlobalAveragePooling2D()(x)
x = Reshape(shape, name='reshape_1')(x)
# TensorFlow Liteでは非対応なのでコメントアウト
# x = Dropout(dropout, name='dropout')(x)
x = Conv2D(classes, (1, 1), padding='same', name='conv_preds')(x)
x = Activation('softmax', name='act_softmax')(x)
x = Reshape((classes,), name='reshape_2')(x)
エラーメッセージにはMergeがUnsupportedという内容がありました。MergeオペレータはBatchNormalizationレイヤー内のtf_utils.smart_cond関数で使われていました。この関数は第1引数がTrueならば第2引数の関数、Falseならば第3引数の関数を使うモデルを作成します。第1引数は学習ならばTrue予測ならばFalseで、TensorFlow Liteでは予測を行うので、Falseの場合のみの実装があればよいです。tf_utils.smart_cond関数を使わないように書き換えます。
def _fused_batch_norm(self, inputs, training):
"""Returns the output of fused batch norm."""
beta = self.beta if self.center else self._beta_const
gamma = self.gamma if self.scale else self._gamma_const
def _fused_batch_norm_training():
return nn.fused_batch_norm(
inputs,
gamma,
beta,
epsilon=self.epsilon,
data_format=self._data_format)
def _fused_batch_norm_inference():
return nn.fused_batch_norm(
inputs,
gamma,
beta,
mean=self.moving_mean,
variance=self.moving_variance,
epsilon=self.epsilon,
is_training=False,
data_format=self._data_format)
#この関数は非対応のMergeオペレータを使用しているので使わない
#output, mean, variance = tf_utils.smart_cond(
# training, _fused_batch_norm_training, _fused_batch_norm_inference)
# 予測時に使用する方の関数を使う
output, mean, variance = _fused_batch_norm_inference()
どのようにしてDropoutとBatchNormalizationに原因があることを突き止めたかという解説は次の説で行います。
MobileNet呼び出し元とMobileNet内部でインポートするソースコードを変更します。
# 書き換えたカレントディレクトリのMobileNetをインポート
from mobilenet import MobileNet
model = MobileNet(
input_shape=(224,224,3),
alpha=0.5,weights=None, classes=101)
# インポートできないのでカレントディレクトリのimagenet_utilsモジュールを利用
# from tensorflow.python.keras.applications import imagenet_utils
# from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
# from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
import imagenet_utils
from imagenet_utils import _obtain_input_shape
from imagenet_utils import decode_predictions
# 改造したカレントディレクトリのBatchNormalizationを利用
# from tensorflow.python.keras.layers import BatchNormalization
from normalization import BatchNormalization
再びSavedModelの作成とtflite_converterによる変換を行うとエラーが発生せずにTensorFlow Lite モデルファイルを作成することができました。
使用しているオペレータを調査する方法
Kerasは使いやすい上位API群ですが、TensorFlowのオペレータを隠蔽していまいます。どのようなオペレータが使用されているかはTensorBoardで調査することができます。
まずTensorBoard用にログを出力します。
import tensorflow as tf
model = tf.keras.applications.MobileNet(
input_shape=(224,224,3),
alpha=0.5,weights=None, classes=101)
sess = tf.keras.backend.get_session()
writer = tf.summary.FileWriter("log", sess.graph)
writer.close()
TensorBoardを起動します。
tensorboard --logdir log
ブラウザでアクセスします。 http://localhost:6006/
ノードを開いていくことで、BatchNormalization付近でMergeオペレータを使用していることが分かりました。