我在Tensor flow Camera Demo中使用自定义模型进行分类。 我生成了一个.bp文件,我可以显示它包含的巨大图表。 要将此图转换为优化图,如[https://www.oreilly.com/learning/tensorflow-on-android]中所示,可以使用以下过程:
$ bazel-bin/tensorflow/python/tools/optimize_for_inference \
--input=tf_files/retrained_graph.pb \
--output=tensorflow/examples/android/assets/retrained_graph.pb
--input_names=Mul \
--output_names=final_result
这里是如何从图形显示中找到input_names和output_names。 当我不使用专有名称时,我会遇到设备崩溃:
E/TensorFlowInferenceInterface(16821): Failed to run TensorFlow inference
with inputs:[AvgPool], outputs:[predictions]
E/AndroidRuntime(16821): FATAL EXCEPTION: inference
E/AndroidRuntime(16821): java.lang.IllegalArgumentException: Incompatible
shapes: [1,224,224,3] vs. [32,1,1,2048]
E/AndroidRuntime(16821): [[Node: dropout/dropout/mul = Mul[T=DT_FLOAT,
_device="/job:localhost/replica:0/task:0/cpu:0"](dropout/dropout/div,
dropout/dropout/Floor)]]
答案 0 :(得分:16)
试试这个:
运行python
>>> import tensorflow as tf
>>> gf = tf.GraphDef()
>>> gf.ParseFromString(open('/your/path/to/graphname.pb','rb').read())
然后
>>> [n.name + '=>' + n.op for n in gf.node if n.op in ( 'Softmax','Placeholder')]
然后,你可以得到类似的结果:
['Mul=>Placeholder', 'final_result=>Softmax']
但我不确定是否存在有关错误消息的节点名称问题。 我猜你在加载图形文件或生成的图形文件有问题时提供了错误的论据?
检查此部分:
E/AndroidRuntime(16821): java.lang.IllegalArgumentException: Incompatible
shapes: [1,224,224,3] vs. [32,1,1,2048]
<强>更新强> 抱歉, 如果你正在使用(重新)训练的图表,那么试试这个:
[n.name + '=>' + n.op for n in gf.node if n.op in ( 'Softmax','Mul')]
似乎(重新)训练的图表将输入/输出操作名称保存为&#34; Mul&#34;和&#34; Softmax&#34;,同时优化和/或量化图将它们保存为&#34;占位符&#34;和&#34; Softmax&#34;。
根据Peter Warden的帖子https://petewarden.com/2016/09/27/tensorflow-for-mobile-poets/,不建议在移动环境中使用重新训练的图表 BTW,。由于性能和文件大小问题,最好使用量化或memmapped图形,我无法找到如何在android中加载memmapped图形虽然...... :( (在android中加载优化/量化图没问题)答案 1 :(得分:7)
最近我直接从tensorflow看到了这个选项:
bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph
--in_graph=custom_graph_name.pb
答案 2 :(得分:0)
我写了一个简单的脚本来分析计算图(通常是DAG,直接是非循环图)中的依赖关系。显而易见,输入是缺少输入的节点。但是,可以将输出定义为图中的任何节点,因为在最怪异但仍然有效的情况下,可以将输入作为输入,而将其他节点都设为虚拟。我仍然将输出操作定义为代码中没有输出的节点。您可以随心所欲地忽略它。
import tensorflow as tf
def load_graph(frozen_graph_filename):
with tf.io.gfile.GFile(frozen_graph_filename, "rb") as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
return graph
def analyze_inputs_outputs(graph):
ops = graph.get_operations()
outputs_set = set(ops)
inputs = []
for op in ops:
if len(op.inputs) == 0 and op.type != 'Const':
inputs.append(op)
else:
for input_tensor in op.inputs:
if input_tensor.op in outputs_set:
outputs_set.remove(input_tensor.op)
outputs = list(outputs_set)
return (inputs, outputs)