为什么在TF2中不推荐使用convert_variables_to_constants()?

时间:2019-07-02 16:53:57

标签: c python-3.x tensorflow tensorflow2.0

正如标题所述,为什么在tensorflow 2中不建议使用convert_variables_to_constants()?要获得可保存的模型以加载到下游独立应用程序中进行推理(在我的情况下,使用C API),最简单的替换方法是

1 个答案:

答案 0 :(得分:0)

在TF 2.x中没有tf.Session(),这是在TF 1.x和TF 2.0中构建冻结模型的必要组件。

根据TensorFlow 2.0.0 release description,“删除了Frozen_graph命令行工具;应使用SavedModel代替冻结图。”因此,应该只使用SavedModel

但是,如果您仍然需要冻结的图形,则

# Save model to SavedModel format
tf.saved_model.save(model, "./models/simple_model")

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    x=tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

layers = [op.name for op in frozen_func.graph.get_operations()]

然后将其保存为冻结图。

注意:您现在应该使用TF 1.x加载此冻结图。 功能

tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="simple_frozen_graph.pb",
                  as_text=False)

然后要加载此模型( TF 1.x 代码)-

with tf.io.gfile.GFile("./frozen_models/simple_frozen_graph.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    loaded = graph_def.ParseFromString(f.read())

减少freeze_graph的等待时间对于应用程序而言可能非常重要,而存储在SavedModel中的高精度权重可能会成为问题。但是也有一些简单的方法可以解决这个问题,这超出了这个问题的范围。