Exporting a trained Keras model as a .pb file

Posted at 2022-07-04

Packages

tensorflow 2.6.0
keras 2.6.0
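
The script was written against the versions listed above, so it can help to confirm what is installed before running it. The following is a minimal sketch using only the standard library (Python 3.8 or later); the helper names `installed_version` and `check_versions` are hypothetical and not part of the original article.

```python
from importlib import metadata

# Versions listed in the article.
EXPECTED = {"tensorflow": "2.6.0", "keras": "2.6.0"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_versions(expected=EXPECTED):
    """Return {package: (expected, installed)} for every package that differs."""
    mismatches = {}
    for package, wanted in expected.items():
        found = installed_version(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


# Report any difference from the versions the article was tested with.
for package, (wanted, found) in check_versions().items():
    print(f"{package}: expected {wanted}, found {found or 'not installed'}")
```

A mismatch does not necessarily mean the conversion will fail; it only indicates that the environment differs from the one the article used.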

Code

convert.py
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

model = tf.keras.applications.vgg16.VGG16(weights='imagenet')
model.save("tf_vgg16", save_format='tf')

new_model = tf.keras.models.load_model('tf_vgg16')
# If the model contains custom layers, pass them via custom_objects, for example:
# new_model = tf.keras.models.load_model(model_path, custom_objects={"DogsAndCatsModel": DogsAndCatsModel})

# Wrap the model call in a tf.function, running in inference mode
full_model = tf.function(lambda inputs: new_model(inputs, training=False))

# Set the shape/dtype of the inputs
frozen_inputs = []
for in_ in new_model.inputs:
    frozen_inputs.append(tf.TensorSpec(in_.shape, in_.dtype))

full_model = full_model.get_concrete_function(*frozen_inputs)
frozen_func = convert_variables_to_constants_v2(full_model)

tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir='log', name='tf_vgg16.pb', as_text=False)
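
Once convert.py has been run, one way to sanity-check the result is to read the .pb file back and look at its nodes. The sketch below is not from the original article: `load_graph_def` and `summarize_nodes` are hypothetical helper names, and it assumes that the graph inputs appear as `Placeholder` ops and that a fully frozen graph contains no `VarHandleOp` or `VariableV2` nodes.

```python
from collections import Counter


def load_graph_def(pb_path):
    """Read a binary GraphDef (.pb) file from disk."""
    # Imported here so summarize_nodes can be used without TensorFlow installed.
    import tensorflow as tf

    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def


def summarize_nodes(nodes):
    """Return (placeholder names, op-type counts) for an iterable of graph nodes.

    Each node only needs `name` and `op` attributes, as GraphDef nodes have.
    """
    nodes = list(nodes)
    placeholders = [n.name for n in nodes if n.op == "Placeholder"]
    op_counts = Counter(n.op for n in nodes)
    return placeholders, op_counts
```

Calling `summarize_nodes(load_graph_def("log/tf_vgg16.pb").node)` returns the placeholder names, which are the candidates for the graph's input tensors, together with a count of each op type. If the counts for `VarHandleOp` and `VariableV2` are both zero, the variables were most likely folded into constants as intended.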

Note: create an empty log directory in the same directory as convert.py before running the script.
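
If creating the directory by hand is inconvenient, it could also be created from Python. This is a small sketch using `pathlib`; the helper name `ensure_log_dir` is hypothetical, and it assumes the directory should sit next to the script whose path is passed in.

```python
from pathlib import Path


def ensure_log_dir(script_path, name="log"):
    """Create `name` next to the given script if it does not exist yet."""
    log_dir = Path(script_path).resolve().parent / name
    # exist_ok=True makes the call safe to repeat on later runs.
    log_dir.mkdir(parents=True, exist_ok=True)
    return log_dir
```

Inside convert.py this would be called as `ensure_log_dir(__file__)` before the `tf.io.write_graph` call.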
