我想知道是否有办法知道tflite中特定节点的输入和输出列表?我知道我可以获得输入/输出的详细信息,但这不允许我重构Interpreter
内部发生的计算过程。所以我要做的是:
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.get_tensor_details()
最后3个命令基本上给了我一些词典,这些词典似乎没有必要的信息。
所以我想知道是否有办法知道每个节点的输出在哪里? Interpreter
当然知道这一点。我们可以吗?谢谢。
答案 0 :(得分:1)
TF-Lite的机制使检查图形和获取内部节点的中间值的整个过程变得有些棘手。其他答案建议的get_tensor(...)
方法无效。
可以使用visualize.py中的TensorFlow Lite repository脚本来可视化TensorFlow Lite模型。您只需要:
使用bazel运行visualize.py
脚本:
bazel run //tensorflow/lite/tools:visualize \
model.tflite \
visualized_model.html
否!实际上,TF-Lite可以修改您的图形,使其变得更优化。以下是TF-Lite documentation中的一些相关内容:
TensorFlow Lite可以处理许多TensorFlow操作,即使它们没有直接的等效项。对于可以从图形中简单删除的操作(tf.identity),替换为张量(tf.placeholder)或融合为更复杂的操作(tf.nn.bias_add)就是这种情况。有时甚至可以通过这些过程之一来删除某些受支持的操作。
此外,TF-Lite API当前不允许获取节点对应关系。很难解释TF-Lite的内部格式。因此,即使没有下面的一个问题,也无法获得任何所需节点的中间输出。
否!在这里,我将解释为什么get_tensor(...)
在TF-Lite中不起作用。假设在内部表示中,图形包含3个张量,以及它们之间的一些密集操作(节点)(您可以将tensor1
视为模型的输入,tensor3
作为模型的输出)。在推断此特定图时,仅TF-Lite 需要2个缓冲区,让我们来演示一下。
首先,使用tensor1
通过应用tensor2
操作来计算dense
。这仅需要2个缓冲区来存储值:
dense dense
[tensor1] -------> [tensor2] -------> [tensor3]
^^^^^^^ ^^^^^^^
bufferA bufferB
第二,使用存储在tensor2
中的bufferB
的值来计算tensor3
...但是要等一下!我们不再需要bufferA
,所以让我们用它来存储tensor3
的值:
dense dense
[tensor1] -------> [tensor2] -------> [tensor3]
^^^^^^^ ^^^^^^^
bufferB bufferA
现在是棘手的部分。 tensor1
的“输出值”仍将指向bufferA
,该值现在保存tensor3
的值。因此,如果您为第一个张量调用get_tensor(...)
,则会得到不正确的值。 documentation of this method甚至声明:
此功能不能用于读取中间结果。
简便但受限制的方式。您可以指定节点的名称,以及要在转换期间获取其值的输出张量:
tflite_convert \
-- # other options of your model
--output_arrays="output_node,intermediate/node/n1,intermediate/node/n2"
硬而灵活的方式。您可以使用Bazel编译TF-Lite(使用this instruction)。然后,您实际上可以向文件Interpreter::Invoke()
中的tensorflow/lite/interpreter.cc
注入一些日志记录代码。丑陋的骇客,但行得通。
答案 1 :(得分:-2)
tf.lite.Interpreter提供了一个API,可以在张量级别(https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter#get_tensor)上窥探模型。因此,基本上,您需要了解节点的输入和输出的张量索引,然后在每次推断时调用“ get_tensor()”。