最近,我尝试将 TF1.x 模型转换为 SavedModel 格式,并遵循了官方的迁移文档(migrate document)。但是,在我的用例中,我手头的模型以及 TensorFlow Model Zoo 中的大多数模型通常都是 .pb 文件,而官方文档(official document)中说:
没有直接的方法可以将原始 Graph.pb 文件升级到 TensorFlow 2.0,但是如果您有“冻结图”(即变量已被转换为常量的 tf.Graph),则可以使用 v1.wrap_function 将其转换为 concrete_function:
但我仍然不明白如何从这里进一步转换为 SavedModel 格式。
答案 0 :(得分:2)
在TF1模式下:
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
def convert_pb_to_server_model(pb_model_path, export_dir, input_name='input:0', output_name='output:0'):
    """Turn a frozen TF1 ``.pb`` graph file into a SavedModel directory.

    Args:
        pb_model_path: Path of the binary frozen-graph file to read.
        export_dir: Directory the SavedModel is exported to.
        input_name: Graph tensor exposed as the signature's ``"input"`` entry.
        output_name: Graph tensor exposed as the signature's ``"output"`` entry.
    """
    # Read first, then export; names are forwarded unchanged.
    convert_pb_saved_model(read_pb_model(pb_model_path), export_dir, input_name, output_name)
def read_pb_model(pb_model_path):
    """Load a binary serialized ``GraphDef`` from ``pb_model_path``.

    Returns:
        The parsed ``tf.GraphDef``.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_model_path, "rb") as pb_file:
        serialized = pb_file.read()
    # ParseFromString expects the binary protobuf wire format.
    graph_def.ParseFromString(serialized)
    return graph_def
def convert_pb_saved_model(graph_def, export_dir, input_name='input:0', output_name='output:0'):
    """Export an in-memory frozen ``GraphDef`` as a TF1 SavedModel.

    The graph is imported into a fresh ``tf.Graph`` without a name prefix. A
    single predict signature is registered under the default serving key,
    mapping ``"input"`` -> ``input_name`` and ``"output"`` -> ``output_name``.
    The meta graph is tagged ``tag_constants.SERVING``.

    Args:
        graph_def: Frozen ``tf.GraphDef`` (variables already folded into constants).
        export_dir: Directory the SavedModel is written to.
            NOTE(review): SavedModelBuilder presumably refuses a non-empty
            existing directory — confirm for your TF version.
        input_name: Tensor name (``"op:index"``) or bare op name of the input.
            A bare op name is resolved to the op's first output (``":0"``).
        output_name: Same as ``input_name``, for the output tensor.
    """
    def _tensor_name(name):
        # get_tensor_by_name() only accepts "<op>:<index>"; default to output 0.
        return name if ':' in name else name + ':0'

    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    with tf.Session(graph=tf.Graph()) as sess:
        # name="" keeps the original node names so the lookups below match.
        tf.import_graph_def(graph_def, name="")
        graph = sess.graph
        inp = graph.get_tensor_by_name(_tensor_name(input_name))
        out = graph.get_tensor_by_name(_tensor_name(output_name))
        sigs = {
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                tf.saved_model.signature_def_utils.predict_signature_def(
                    {"input": inp}, {"output": out}),
        }
        builder.add_meta_graph_and_variables(sess,
                                             [tag_constants.SERVING],
                                             signature_def_map=sigs)
    builder.save()
在TF2模式下(注意:以下代码执行的是反向转换,即把 Keras 模型或 SavedModel 导出为冻结的 .pb 图):
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
from tensorflow.lite.python.util import run_graph_optimizations, get_grappler_config
import numpy as np
def frozen_keras_graph(func_model):
    """Freeze a concrete function and return a Grappler-optimized ``GraphDef``.

    Captured variables are folded into constants. The frozen graph is then
    run through Grappler's ``constfold`` and ``function`` passes.

    Args:
        func_model: A concrete function, e.g. a traced Keras model or a
            SavedModel signature.

    Returns:
        The optimized frozen ``GraphDef``.
    """
    frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(func_model)
    # Inputs of dtype tf.resource are not passed to the optimizer as graph inputs.
    data_inputs = list(filter(lambda t: t.dtype != tf.resource, frozen_func.inputs))
    grappler_config = get_grappler_config(["constfold", "function"])
    return run_graph_optimizations(
        graph_def,
        data_inputs,
        frozen_func.outputs,
        config=grappler_config,
        graph=frozen_func.graph,
    )
def convert_keras_model_to_pb(output_dir='/tmp/tf_model3', filename='frozen_graph.pb'):
    """Train a Keras model and write its frozen graph as a binary ``.pb`` file.

    Args:
        output_dir: Directory the graph file is written to.
        filename: Name of the graph file inside ``output_dir``.

    Returns:
        The path of the written file, as returned by ``tf.io.write_graph``.
    """
    # NOTE(review): train_model() is not defined in this snippet; it is
    # assumed to return a built Keras model — confirm against the caller.
    keras_model = train_model()
    model_input = keras_model.inputs[0]
    # Trace with the model's own input shape/dtype to get one concrete function.
    func_model = tf.function(keras_model).get_concrete_function(
        tf.TensorSpec(model_input.shape, model_input.dtype))
    graph_def = frozen_keras_graph(func_model)
    # write_graph() defaults to as_text=True, which would store a text
    # protobuf under a ".pb" name. The binary form is what
    # GraphDef.ParseFromString() expects when the file is read back.
    return tf.io.write_graph(graph_def, output_dir, filename, as_text=False)
def convert_saved_model_to_pb(model_dir='/tmp/saved_model', output_dir='/tmp/tf_model3', filename='frozen_graph.pb'):
    """Freeze a TF2 SavedModel's ``serving_default`` signature into a binary ``.pb``.

    Args:
        model_dir: Directory of the SavedModel to load.
        output_dir: Directory the frozen graph file is written to.
        filename: Name of the graph file inside ``output_dir``.

    Returns:
        The path of the written file, as returned by ``tf.io.write_graph``.
    """
    # Keep `model` bound for the whole function so the signature's captured
    # state stays referenced while it is being frozen.
    model = tf.saved_model.load(model_dir)
    func_model = model.signatures["serving_default"]
    graph_def = frozen_keras_graph(func_model)
    # as_text=False: write the binary protobuf a ".pb" file is expected to
    # hold (write_graph defaults to a text dump).
    return tf.io.write_graph(graph_def, output_dir, filename, as_text=False)
或者(使用 TF1 的 freeze_graph 工具将 SavedModel 冻结为 .pb 文件):
def convert_saved_model_to_pb(output_node_names, input_saved_model_dir, output_graph_dir):
    """Freeze a SavedModel into a single ``.pb`` graph using the TF1 freeze_graph tool.

    Args:
        output_node_names: Output op name(s) to keep. Either an iterable of
            names, or a single string (which may already be comma-separated).
        input_saved_model_dir: Directory of the SavedModel to freeze.
        output_graph_dir: Despite its name, the path of the output graph
            *file*; it is passed as freeze_graph's ``output_graph``.
    """
    from tensorflow.python.tools import freeze_graph
    # freeze_graph wants one comma-separated string. Joining a plain str would
    # split it into single characters ("S,t,a,..."), so pass it through as-is.
    if not isinstance(output_node_names, str):
        output_node_names = ','.join(output_node_names)
    # The graph/checkpoint inputs are left as None; the model comes from
    # input_saved_model_dir instead.
    freeze_graph.freeze_graph(input_graph=None, input_saver=None,
                              input_binary=None,
                              input_checkpoint=None,
                              output_node_names=output_node_names,
                              restore_op_name=None,
                              filename_tensor_name=None,
                              output_graph=output_graph_dir,
                              clear_devices=None,
                              initializer_nodes=None,
                              input_saved_model_dir=input_saved_model_dir)
def save_output_tensor_to_pb():
    """Freeze the SavedModel at ``/tmp/saved_model`` into ``/tmp/pb_model/freeze_graph.pb``.

    ``StatefulPartitionedCall`` is requested as the sole output node.
    NOTE(review): that op name is assumed to be the exported model's call op — confirm.
    """
    convert_saved_model_to_pb(
        ['StatefulPartitionedCall'],
        '/tmp/saved_model',
        '/tmp/pb_model/freeze_graph.pb',
    )
答案 1 :(得分:0)
为了确认我的理解是否正确,我把学到的内容也整理发布如下:
如果想将 TF1.x 代码迁移到 TF2.x,请先阅读官方迁移指南(official post)。
在tensorflow 2.0中,tf.train.Saver和freeze_graph已被saved_model取代。
如果想将 TF1.x 的 .pb 模型转换为 SavedModel,可以参考 @Boluoyu 的回答。但是,如果您的运行环境是 TF 2.0 及以上版本,则可以使用以下代码(通过 tensorflow.compat.v1 运行 TF1 风格的代码):
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
def covert_pb_to_server_model(pb_model_path, export_dir, input_name='input', output_name='output'):
    """Turn a frozen ``.pb`` graph file into a SavedModel (TF1 API via ``compat.v1``).

    Args:
        pb_model_path: Path of the binary frozen-graph file to read.
        export_dir: Directory the SavedModel is exported to.
        input_name: Input tensor/op name, forwarded unchanged to covert_pb_saved_model.
        output_name: Output tensor/op name, forwarded unchanged to covert_pb_saved_model.
    """
    covert_pb_saved_model(read_pb_model(pb_model_path), export_dir, input_name, output_name)
def read_pb_model(pb_model_path):
    """Parse the binary ``GraphDef`` stored at ``pb_model_path`` and return it."""
    with tf.gfile.GFile(pb_model_path, "rb") as handle:
        raw_bytes = handle.read()
    parsed = tf.GraphDef()
    parsed.ParseFromString(raw_bytes)
    return parsed
def covert_pb_saved_model(graph_def, export_dir, input_name='input', output_name='output'):
    """Export an in-memory frozen ``GraphDef`` as a SavedModel (TF1 API via ``compat.v1``).

    The graph is imported into a fresh ``tf.Graph`` without a name prefix. A
    single predict signature is registered under the default serving key,
    mapping ``"input"`` -> ``input_name`` and ``"output"`` -> ``output_name``.
    The meta graph is tagged ``tag_constants.SERVING``.

    Args:
        graph_def: Frozen ``tf.GraphDef`` (variables already folded into constants).
        export_dir: Directory the SavedModel is written to.
        input_name: Tensor name (``"op:index"``) or bare op name of the input.
            A bare op name such as the default ``'input'`` is resolved to the
            op's first output (``'input:0'``).
        output_name: Same as ``input_name``, for the output tensor.
    """
    def _tensor_name(name):
        # get_tensor_by_name() raises ValueError for a bare op name such as the
        # default 'input'; tensors must be addressed as "<op>:<index>".
        return name if ':' in name else name + ':0'

    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    with tf.Session(graph=tf.Graph()) as sess:
        # name="" keeps the original node names so the lookups below match.
        tf.import_graph_def(graph_def, name="")
        graph = sess.graph
        inp = graph.get_tensor_by_name(_tensor_name(input_name))
        out = graph.get_tensor_by_name(_tensor_name(output_name))
        sigs = {
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                tf.saved_model.signature_def_utils.predict_signature_def(
                    {"input": inp}, {"output": out}),
        }
        builder.add_meta_graph_and_variables(sess,
                                             [tag_constants.SERVING],
                                             signature_def_map=sigs)
    builder.save()