在Tensorflow中,获取图表中所有张量的名称

时间:2016-04-27 08:08:30

标签: python tensorflow tensorboard skflow

我正在使用Tensorflowskflow创建神经网络;出于某种原因,我想获得给定输入的一些内部张量的值,因此我使用myClassifier.get_layer_value(input, "tensorName")myClassifierskflow.estimators.TensorFlowEstimator

然而,我发现很难找到张量名称的正确语法,即使知道它的名字(我在操作和张量之间感到困惑),所以我使用张量板绘制图形并寻找名。

有没有办法在不使用张量板的情况下枚举图表中的所有张量?

10 个答案:

答案 0 :(得分:152)

你可以做到

[n.name for n in tf.get_default_graph().as_graph_def().node]

此外,如果您在IPython笔记本中进行原型设计,可以直接在笔记本中显示图表,请参阅Alexander的Deep Dream notebook中的show_graph函数

答案 1 :(得分:23)

使用get_operations,有一种方法可以比雅罗斯拉夫的答案稍快一些。这是一个简单的例子:

/src/resources/static/

答案 2 :(得分:11)

tf.all_variables()可以为您提供所需的信息。

此外,今天在TensorFlow Learn中制作的this commit在估算工具中提供了一个函数get_variable_names,您可以使用该函数轻松检索所有变量名称。

答案 3 :(得分:5)

我认为这也会这样做:

print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))

但与萨尔瓦多和雅罗斯拉夫的答案相比,我不知道哪一个更好。

答案 4 :(得分:5)

接受的答案只会为您提供一个包含名称的字符串列表。我更喜欢一种不同的方法,它可以(几乎)直接访问张量:

graph = tf.get_default_graph()
list_of_tuples = [op.values() for op in graph.get_operations()]

list_of_tuples现在包含每个张量,每个张量都在一个元组内。您也可以调整它以直接获得张量:

graph = tf.get_default_graph()
list_of_tuples = [op.values()[0] for op in graph.get_operations()]

答案 5 :(得分:3)

以前的答案很好,我只是想分享一个我用来从图表中选择张量的实用函数:

def get_graph_op(graph, and_conds=None, op='and', or_conds=None):
    """Selects nodes' names in the graph if:
    - The name contains all items in and_conds
    - OR/AND depending on op
    - The name contains any item in or_conds

    Condition starting with a "!" are negated.
    Returns all ops if no optional arguments is given.

    Args:
        graph (tf.Graph): The graph containing sought tensors
        and_conds (list(str)), optional): Defaults to None.
            "and" conditions
        op (str, optional): Defaults to 'and'. 
            How to link the and_conds and or_conds:
            with an 'and' or an 'or'
        or_conds (list(str), optional): Defaults to None.
            "or conditions"

    Returns:
        list(str): list of relevant tensor names
    """
    assert op in {'and', 'or'}

    if and_conds is None:
        and_conds = ['']
    if or_conds is None:
        or_conds = ['']

    node_names = [n.name for n in graph.as_graph_def().node]

    ands = {
        n for n in node_names
        if all(
            cond in n if '!' not in cond
            else cond[1:] not in n
            for cond in and_conds
        )}

    ors = {
        n for n in node_names
        if any(
            cond in n if '!' not in cond
            else cond[1:] not in n
            for cond in or_conds
        )}

    if op == 'and':
        return [
            n for n in node_names
            if n in ands.intersection(ors)
        ]
    elif op == 'or':
        return [
            n for n in node_names
            if n in ands.union(ors)
        ]

所以如果你有一个带有ops的图表:

['model/classifier/dense/kernel',
'model/classifier/dense/kernel/Assign',
'model/classifier/dense/kernel/read',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd',
'model/classifier/ArgMax/dimension',
'model/classifier/ArgMax']

然后运行

get_graph_op(tf.get_default_graph(), ['dense', '!kernel'], 'or', ['Assign'])

返回:

['model/classifier/dense/kernel/Assign',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd']

答案 6 :(得分:2)

我将尝试总结答案:

获取所有节点:

all_nodes = [n for n in tf.get_default_graph().as_graph_def().node]

这些类型为tensorflow.core.framework.node_def_pb2.NodeDef

获取所有操作:

all_ops = tf.get_default_graph().get_operations()

这些类型为tensorflow.python.framework.ops.Operation

获取所有变量:

all_vars = tf.global_variables()

这些类型为tensorflow.python.ops.resource_variable_ops.ResourceVariable

最后,要回答这个问题,获得所有张量

all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]

这些类型为tensorflow.python.framework.ops.Tensor

答案 7 :(得分:1)

由于OP请求张量的列表而不是操作/节点的列表,因此代码应略有不同:

graph = tf.get_default_graph()    
tensors_per_node = [node.values() for node in graph.get_operations()]
tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]

答案 8 :(得分:0)

这对我有用:

for n in tf.get_default_graph().as_graph_def().node:
    print('\n',n)

答案 9 :(得分:0)

以下解决方案在TensorFlow 2.3中对我有效-

def load_pb(path_to_pb):
    with tf.io.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph
tf_graph = load_pb(MODEL_FILE)
sess = tf.compat.v1.Session(graph=tf_graph)

# Show tensor names in graph
for op in tf_graph.get_operations():
    print(op.values())

其中MODEL_FILE是冻结图的路径。

取自here