TensorFlow2.0中的XLA-冻结模型?

时间:2019-11-29 16:36:20

标签: tensorflow xla

我一直遵循有关XLA AOT编译的官方指南(https://www.tensorflow.org/xla/tfcompile),并且编译示例工作正常(在aot /测试内部)。

但是后来我想编译一些更大的模型,然后出现一个问题:如果XLA AOT需要将冻结的图作为输入(据我从指南中了解),并且TensorFlow 2中不再支持冻结的图,那么输入会做什么? XLA现在期望吗?

1 个答案:

答案 0 :(得分:0)

似乎在TensorFlow 2中仍然有冻结图的方法。我按照这篇文章创建了一个冻结图,并随后对其进行了编译:https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)

print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)

# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)