在Tensorflow图中检查节点是操作还是张量

时间:2018-06-15 15:54:24

标签: python tensorflow

使用names = [n.name for n in graph.as_graph_def().node]我可以获取图表中的所有节点名称。

例如,说这打印:

['model/classifier/dense/kernel/Initializer/random_uniform/shape',
'model/classifier/dense/kernel/Initializer/random_uniform/min',
'model/classifier/dense/kernel/Initializer/random_uniform/max',
'model/classifier/dense/kernel/Initializer/random_uniform/RandomUniform',
'model/classifier/dense/kernel/Initializer/random_uniform/sub',
'model/classifier/dense/kernel/Initializer/random_uniform/mul',
'model/classifier/dense/kernel/Initializer/random_uniform',
'model/classifier/dense/kernel',
'model/classifier/dense/kernel/Assign',
'model/classifier/dense/kernel/read',
'model/classifier/dense/bias/Initializer/zeros/shape_as_tensor',
'model/classifier/dense/bias/Initializer/zeros/Const',
'model/classifier/dense/bias/Initializer/zeros',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd'] 

如何仅选择操作或仅选择张量?

我知道以下变通方法,这些变通方法在特定情况下有效,但不够通用,无法扩展到大图:

  • 来自上述名称的字符串操作

    • 例如获取model/classifier/dense/kernel

      tensor = [graph.get_tensor_by_name(n + ":0") 
          for n in names if 'classifier' in n and
          'kernel' in name and 
          not n.split('kernel')[-1]
      ][0]
      
    • 正如您可以想象的那样,这可能非常容易出错并且具有特定的张量
  • try/except我可以通过以下方式获得作为这些操作输出的张量:

    tensors = []
    for name in names:
        try:
            tensors.append(graph.get_tensor_by_name(name + ":0"))
        except KeyError:
            pass
    
    • 但这又是一个解决方法,并没有解决选择问题:如果我只想要kernel Tensor怎么办?

3 个答案:

答案 0 :(得分:1)

如果您想要更好地了解操作和节点,请尝试运行tensorboard。 您可以使用tf.summary.FileWriter("folder_name", sess.graph)编写摘要文件。

我的张量流量知识有限,但我认为张量名称和运算符名称几乎相同。操作员可以有多个输出,这些输出中的每一个都称为张量。因此,张量名称仅为operator_name:output_indexoutput_index通常为0,因为大多数运算符都有单个输出。 所以给运行sess.graph.get_tensor_by_name("model/classifier/dense/kernel/Initializer/random_uniform/mul:0")一个机会。我不确定这些长名字是否实用。

如果提供的信息不是100%正确,我很抱歉,我只是个初学者。

答案 1 :(得分:1)

好的,我找到了答案。它主要是因为我真正想要的是 Variables 而不仅仅是常规的Tensors。

因此它很简单:

with graph.as_default():
    kernel = [v for v in tf.global_variables()
        if 'optimization' not in v.name and
        'classifier' in v.name
        and 'kernel' in v.name
    ][0]

答案 2 :(得分:-1)

您可以使用isinstance(item,class)并将节点与tf.Operation类进行比较,如[n.name for n in graph.as_graph_def().node if isinstance(n, tf.Operation)]