パッケージ
tensorflow 2.6.0
keras 2.6.0
コード
convert.py
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
model = tf.keras.applications.vgg16.VGG16(weights='imagenet')
model.save("tf_vgg16", save_format='tf')
new_model = tf.keras.models.load_model('tf_vgg16')
full_model = tf.function(lambda inputs: new_model(inputs, training=False))
# カスタムレイヤーがある場合は、例えば以下のように書く
# new_model = tf.keras.models.load_model(model_path, custom_objects={"DogsAndCatsModel": DogsAndCatsModel})
# inputsのshape/dtypeをセットする
frozen_inputs = []
for in_ in new_model.inputs:
frozen_inputs.append(tf.TensorSpec(in_.shape, in_.dtype))
full_model = full_model.get_concrete_function(*frozen_inputs)
frozen_func = convert_variables_to_constants_v2(full_model)
tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir='log', name='tf_vgg16.pb', as_text=False)
※convert.pyと同じディレクトリに空のlog ディレクトリを用意しておく