可视化TFLite图并获取特定节点的中间值?

时间:2019-07-04 09:38:12

标签: tensorflow keras tensorflow-lite tensorflow2.0

我想知道是否有办法知道tflite中特定节点的输入和输出列表?我知道我可以获得输入/输出的详细信息,但这不允许我重构Interpreter内部发生的计算过程。所以我要做的是:

interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.get_tensor_details()

最后3个命令基本上给了我一些词典,这些词典似乎没有必要的信息。

所以我想知道是否有办法知道每个节点的输出在哪里? Interpreter当然知道这一点。我们可以吗?谢谢。

2 个答案:

答案 0 :(得分:1)

TF-Lite的机制使检查图形和获取内部节点的中间值的整个过程变得有些棘手。其他答案建议的get_tensor(...)方法无效。

如何可视化TF-Lite推理图?

可以使用visualize.py中的TensorFlow Lite repository脚本来可视化TensorFlow Lite模型。您只需要:

  • Clone the TensorFlow repository
  • 使用bazel运行visualize.py脚本:

    bazel run //tensorflow/lite/tools:visualize \
         model.tflite \
         visualized_model.html
    

我的TF模型中的节点在TF-Lite中是否具有等效的节点?

否!实际上,TF-Lite可以修改您的图形,使其变得更优化。以下是TF-Lite documentation中的一些相关内容:

  

TensorFlow Lite可以处理许多TensorFlow操作,即使它们没有直接的等效项。对于可以从图形中简单删除的操作(tf.identity),替换为张量(tf.placeholder)或融合为更复杂的操作(tf.nn.bias_add)就是这种情况。有时甚至可以通过这些过程之一来删除某些受支持的操作。

此外,TF-Lite API当前不允许获取节点对应关系。很难解释TF-Lite的内部格式。因此,即使没有下面的一个问题,也无法获得任何所需节点的中间输出。

我可以获取某些TF-Lite节点的中间值吗?

否!在这里,我将解释为什么get_tensor(...)在TF-Lite中不起作用。假设在内部表示中,图形包含3个张量,以及它们之间的一些密集操作(节点)(您可以将tensor1视为模型的输入,tensor3作为模型的输出)。在推断此特定图时,仅TF-Lite 需要2个缓冲区,让我们来演示一下。

首先,使用tensor1通过应用tensor2操作来计算dense。这仅需要2个缓冲区来存储值:

           dense              dense
[tensor1] -------> [tensor2] -------> [tensor3]
 ^^^^^^^            ^^^^^^^
 bufferA            bufferB

第二,使用存储在tensor2中的bufferB的值来计算tensor3 ...但是要等一下!我们不再需要bufferA,所以让我们用它来存储tensor3的值:

           dense              dense
[tensor1] -------> [tensor2] -------> [tensor3]
                    ^^^^^^^            ^^^^^^^
                    bufferB            bufferA

现在是棘手的部分。 tensor1的“输出值”仍将指向bufferA,该值现在保存tensor3的值。因此,如果您为第一个张量调用get_tensor(...),则会得到不正确的值。 documentation of this method甚至声明:

  

此功能不能用于读取中间结果。

如何解决这个问题?

  • 简便但受限制的方式。您可以指定节点的名称,以及要在转换期间获取其值的输出张量:

    tflite_convert \
        -- # other options of your model
        --output_arrays="output_node,intermediate/node/n1,intermediate/node/n2"
    
  • 硬而灵活的方式。您可以使用Bazel编译TF-Lite(使用this instruction)。然后,您实际上可以向文件Interpreter::Invoke()中的tensorflow/lite/interpreter.cc注入一些日志记录代码。丑陋的骇客,但行得通。

答案 1 :(得分:-2)

tf.lite.Interpreter提供了一个API,可以在张量级别(https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter#get_tensor)上窥探模型。因此,基本上,您需要了解节点的输入和输出的张量索引,然后在每次推断时调用“ get_tensor()”。