使用names = [n.name for n in graph.as_graph_def().node]
我可以获取图表中的所有节点名称。
例如,说这打印:
['model/classifier/dense/kernel/Initializer/random_uniform/shape',
'model/classifier/dense/kernel/Initializer/random_uniform/min',
'model/classifier/dense/kernel/Initializer/random_uniform/max',
'model/classifier/dense/kernel/Initializer/random_uniform/RandomUniform',
'model/classifier/dense/kernel/Initializer/random_uniform/sub',
'model/classifier/dense/kernel/Initializer/random_uniform/mul',
'model/classifier/dense/kernel/Initializer/random_uniform',
'model/classifier/dense/kernel',
'model/classifier/dense/kernel/Assign',
'model/classifier/dense/kernel/read',
'model/classifier/dense/bias/Initializer/zeros/shape_as_tensor',
'model/classifier/dense/bias/Initializer/zeros/Const',
'model/classifier/dense/bias/Initializer/zeros',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd']
如何仅选择操作或仅选择张量?
我知道以下变通方法,这些变通方法在特定情况下有效,但不够通用,无法扩展到大图:
来自上述名称的字符串操作
例如获取model/classifier/dense/kernel
:
tensor = [graph.get_tensor_by_name(n + ":0")
for n in names if 'classifier' in n and
'kernel' in name and
not n.split('kernel')[-1]
][0]
try/except
我可以通过以下方式获得作为这些操作输出的张量:
tensors = []
for name in names:
try:
tensors.append(graph.get_tensor_by_name(name + ":0"))
except KeyError:
pass
kernel
Tensor怎么办?答案 0 :(得分:1)
如果您想要更好地了解操作和节点,请尝试运行tensorboard。
您可以使用tf.summary.FileWriter("folder_name", sess.graph)
编写摘要文件。
我的张量流量知识有限,但我认为张量名称和运算符名称几乎相同。操作员可以有多个输出,这些输出中的每一个都称为张量。因此,张量名称仅为operator_name:output_index
,output_index
通常为0
,因为大多数运算符都有单个输出。
所以给运行sess.graph.get_tensor_by_name("model/classifier/dense/kernel/Initializer/random_uniform/mul:0")
一个机会。我不确定这些长名字是否实用。
如果提供的信息不是100%正确,我很抱歉,我只是个初学者。
答案 1 :(得分:1)
好的,我找到了答案。它主要是因为我真正想要的是 Variables 而不仅仅是常规的Tensors。
因此它很简单:
with graph.as_default():
kernel = [v for v in tf.global_variables()
if 'optimization' not in v.name and
'classifier' in v.name
and 'kernel' in v.name
][0]
答案 2 :(得分:-1)
您可以使用isinstance(item,class)并将节点与tf.Operation类进行比较,如[n.name for n in graph.as_graph_def().node if isinstance(n, tf.Operation)]