How to list all the TensorFlow variables a node depends on?

Date: 2017-02-15 18:09:49

Tags: python tensorflow

How can I list all of the TensorFlow variables/constants/placeholders a node depends on?

Example 1 (adding constants):


import tensorflow as tf

a = tf.constant(1, name='a')
b = tf.constant(3, name='b')
c = tf.constant(9, name='c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')

sess = tf.Session()
print(sess.run([d, e]))

I would like to have a function list_dependencies() such that:

  • list_dependencies(d) returns ['a', 'b']
  • list_dependencies(e) returns ['a', 'b', 'c']

Example 2 (matrix multiplication between a placeholder and a weight matrix, followed by addition of a bias vector):

tf.set_random_seed(1)
input_size = 5
output_size = 3
input = tf.placeholder(tf.float32, shape=[1, input_size], name='input')
W = tf.get_variable(
    "W",
    shape=[input_size, output_size],
    initializer=tf.contrib.layers.xavier_initializer())
b = tf.get_variable(
    "b",
    shape=[output_size],
    initializer=tf.constant_initializer(2))
output = tf.matmul(input, W, name="output")
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias")

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run([output, output_bias], feed_dict={input: [[2]*input_size]}))

I would like list_dependencies() to behave the same way here, e.g.:

  • list_dependencies(output) returns ['W', 'input']
  • list_dependencies(output_bias) returns ['W', 'input', 'b']

4 Answers:

Answer 0 (score: 9)

Here are the utilities I use for this (from https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py):

# computation flows from parents to children

def parents(op):
  return set(input.op for input in op.inputs)

def children(op):
  return set(op for out in op.outputs for op in out.consumers())

def get_graph():
  """Creates dictionary {node: {child1, child2, ..},..} for current
  TensorFlow graph. Result is compatible with networkx/toposort"""

  ops = tf.get_default_graph().get_operations()
  return {op: children(op) for op in ops}


def print_tf_graph(graph):
  """Prints tensorflow graph in dictionary form."""
  for node in graph:
    for child in graph[node]:
      print("%s -> %s" % (node.name, child.name))

These functions work on ops. To get the op that produces a tensor t, use t.op. To get the tensors produced by an op, use op.outputs.
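Building on the parents() helper above, a minimal sketch of the list_dependencies() function the question asks for — written here over a plain {node: parents} dictionary so it runs without TensorFlow; with a real graph you would seed the dictionary from t.op and parents(op):

```python
def list_dependencies(node, parents_of):
    """Return the sorted leaf ancestors (nodes with no parents of their
    own, i.e. constants/placeholders/variables) that `node` transitively
    depends on. `parents_of` maps each node to its set of parents."""
    seen, stack = set(), [node]
    while stack:
        for parent in parents_of.get(stack.pop(), set()):
            if parent not in seen:
                seen.add(parent)
                stack.append(parent)
    return sorted(n for n in seen if not parents_of.get(n))

# Mirror of example 1 from the question: d = a + b, e = d + c
deps = {'d': {'a', 'b'}, 'e': {'d', 'c'}}
print(list_dependencies('d', deps))  # ['a', 'b']
print(list_dependencies('e', deps))  # ['a', 'b', 'c']
```

Filtering to leaf nodes is what drops the intermediate op d from the dependencies of e, matching the output the question expects.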

Answer 1 (score: 3)

Yaroslav Bulatov's answer is great; I'll just add a plotting function that uses Yaroslav's get_graph() and children() methods:

import matplotlib.pyplot as plt
import networkx as nx
def plot_graph(G):
    '''Plot a DAG using NetworkX'''        
    def mapping(node):
        return node.name
    G = nx.DiGraph(G)
    nx.relabel_nodes(G, mapping, copy=False)
    nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True)
    plt.show()

plot_graph(get_graph())

Plotting example 1 from the question:

import matplotlib.pyplot as plt
import networkx as nx
import tensorflow as tf

def children(op):
  return set(op for out in op.outputs for op in out.consumers())

def get_graph():
  """Creates dictionary {node: {child1, child2, ..},..} for current
  TensorFlow graph. Result is compatible with networkx/toposort"""
  ops = tf.get_default_graph().get_operations()
  return {op: children(op) for op in ops}

def plot_graph(G):
    '''Plot a DAG using NetworkX'''        
    def mapping(node):
        return node.name
    G = nx.DiGraph(G)
    nx.relabel_nodes(G, mapping, copy=False)
    nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True)
    plt.show()

a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')

sess = tf.Session()
print(sess.run([d, e]))
plot_graph(get_graph())

Output:

(image: NetworkX plot of the example 1 graph)

Plotting example 2 from the question:

(image: NetworkX plot of the example 2 graph)

If you are using Microsoft Windows, you may run into this issue: Python Error (ValueError: _getfullpathname: embedded null character); in that case you need to patch matplotlib as described in that link.

Answer 2 (score: 0)

In some cases one may want to find all of the "input" variables that a given "output" tensor (for example, the loss of a graph) depends on. The following code snippet, inspired by the code above, may be useful for that:

def findVars(atensor):
    """Recursively collect the Variable ops that `atensor` depends on."""
    allinputs = atensor.op.inputs
    if len(allinputs) == 0:
        # leaf op: keep it only if it is a variable
        if atensor.op.type in ('VariableV2', 'Variable'):
            return set([atensor.op])
        return set()
    a = set()
    for t in allinputs:
        a = a | findVars(t)
    return a

This can be used during debugging to find out where a connection in the graph is missing.
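The recursion in findVars stops at leaf ops and keeps only those of type Variable/VariableV2. The same idea can be sketched without TensorFlow by representing each op as a (type, inputs) pair — the names and table below are illustrative, not TF API:

```python
def find_vars(op, ops):
    """Collect ops of type 'Variable'/'VariableV2' that `op` depends on.
    `ops` maps op name -> (op_type, [input op names])."""
    op_type, inputs = ops[op]
    if not inputs:
        # leaf op: keep it only if it is a variable
        return {op} if op_type in ('Variable', 'VariableV2') else set()
    found = set()
    for name in inputs:
        found |= find_vars(name, ops)
    return found

# Mirror of example 2 from the question: output_bias = input @ W + b
ops = {'input': ('Placeholder', []),
       'W': ('VariableV2', []),
       'b': ('VariableV2', []),
       'output': ('MatMul', ['input', 'W']),
       'output_bias': ('BiasAdd', ['output', 'b'])}
print(sorted(find_vars('output_bias', ops)))  # ['W', 'b']
```

Note that the placeholder input is excluded: like findVars above, this collects only variables, not all leaf inputs.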

Answer 3 (score: 0)

These are all good answers; I will add one more simple way of generating the dependencies, in a less readable format but useful for quick debugging.

tf.get_default_graph().as_graph_def()

This prints the operations in the graph as a simple dictionary-style listing, as shown below. Each op is easy to look up by name, together with its attributes and inputs, which lets you follow the dependencies.

import tensorflow as tf

a = tf.placeholder(tf.float32, name='placeholder_1')
b = tf.placeholder(tf.float32, name='placeholder_2')
c = a + b

tf.get_default_graph().as_graph_def()

Out[14]: 
node {
  name: "placeholder_1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "placeholder_2"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "add"
  op: "Add"
  input: "placeholder_1"
  input: "placeholder_2"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 27
}
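The input: lines in the printout above are exactly the edges needed to trace dependencies by hand. A small sketch that walks them, with the node table hand-copied from the printout (plain Python, no TensorFlow required):

```python
def transitive_inputs(name, inputs_of):
    """Follow the `input:` edges from the graph_def printout to collect
    every node that `name` transitively depends on."""
    deps = set()
    for parent in inputs_of.get(name, []):
        deps.add(parent)
        deps |= transitive_inputs(parent, inputs_of)
    return deps

# Hand-copied from the printout: 'add' lists both placeholders as inputs.
inputs_of = {'placeholder_1': [],
             'placeholder_2': [],
             'add': ['placeholder_1', 'placeholder_2']}
print(sorted(transitive_inputs('add', inputs_of)))
# ['placeholder_1', 'placeholder_2']
```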