Tensorflow中图表中的张量名称列表

时间:2016-02-11 10:24:57

标签: python tensorflow

Tensorflow中的图形对象有一个名为“get_tensor_by_name(name)”的方法。无论如何都要获得有效张量名称列表?

如果没有,是否有人知道预训练模型inception-v3 from here的有效名称?从他们的例子中,pool_3是一个有效的张量,但是所有这些的列表都很好。我查看the paper referred to,其中一些图层似乎与表1中的尺寸相对应,但不是全部。

6 个答案:

答案 0 :(得分:50)

该论文没有准确反映模型。如果从arxiv下载源代码,则它具有精确的模型描述作为model.txt,其中的名称与发布模型中的名称密切相关。

要回答您的第一个问题,sess.graph.get_operations()会为您提供操作列表。对于op,op.name为您提供名称,op.values()为您提供它生成的张量列表(在inception-v3模型中,所有张量名称都是附加了“:0”的操作名称)它,所以pool_3:0是由最终汇集操作产生的张量。)

答案 1 :(得分:20)

要查看图中的操作(你会看到很多,所以要缩短我在这里只给出了第一个字符串)。

sess = tf.Session()
op = sess.graph.get_operations()
[m.values() for m in op][1]

out:
(<tf.Tensor 'conv1/weights:0' shape=(4, 4, 3, 32) dtype=float32_ref>,)

答案 2 :(得分:17)

以上答案是正确的。我遇到了一个易于理解/简单的代码来完成上述任务。所以在这里分享: -

import tensorflow as tf

def printTensors(pb_file):

    # read pb into graph_def
    with tf.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # import graph_def
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

    # print operations
    for op in graph.get_operations():
        print(op.name)


printTensors("path-to-my-pbfile.pb")

答案 3 :(得分:7)

您甚至不必创建会话来查看图表中所有操作名称的名称。要执行此操作,您只需获取默认图表tf.get_default_graph()并提取所有操作:.get_operations。每个操作都有many fields,您需要的是名称。

以下是代码:

import tensorflow as tf
a = tf.Variable(5)
b = tf.Variable(6)
c = tf.Variable(7)
d = (a + b) * c

for i in tf.get_default_graph().get_operations():
    print i.name

答案 4 :(得分:2)

作为嵌套列表理解:

tensor_names = [t.name for op in tf.get_default_graph().get_operations() for t in op.values()]

在图表中获取张量名称的功能(默认为默认图表):

def get_names(graph=tf.get_default_graph()):
    return [t.name for op in graph.get_operations() for t in op.values()]

在图表中获取张量的函数(默认为默认图形):

def get_tensors(graph=tf.get_default_graph()):
    return [t for op in graph.get_operations() for t in op.values()]

答案 5 :(得分:0)

saved_model_cli是TF附带的另一种命令行工具,如果您处理“ SavedModel”格式,可能很有用。从docs

!saved_model_cli show --dir /tmp/mobilenet/1 --tag_set serve --all

此输出可能有用,类似于:

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['dense_input'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1280)
        name: serving_default_dense_input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['dense_1'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict