Tensorflow:以"添加"

时间:2018-04-26 11:58:23

标签: python tensorflow

现在我正在执行以下操作来访问图表中的张量

graph = tf.get_default_graph()
add_0 = graph.get_tensor_by_name("Add:0")
add_1 = graph.get_tensor_by_name("Add_1:0")
add_2 = graph.get_tensor_by_name("Add_2:0")

当图表很短时,这种方法是可以的。但对于较长的图表,它变得非常无聊。

有没有办法以干净的方式收集以Add开头的所有张量?类似的东西:

add = []
for Add in graph.get_tensors_by_name():
    add.append(Add)

(我知道这个伪代码确实是错误的)

这样我得到add = [add_0, add_1, add_2, ... ]

后来我想用它来做到这一点:sess.run(add, feed_dict={input: data})

1 个答案:

答案 0 :(得分:2)

您可以使用sess.graph.get_operations()获取所有张量,然后使用startswith()选择您需要的张量。经过测试的代码:

import tensorflow as tf

a = tf.constant( [ 1.0 ] )
b = tf.constant( [ 2.0 ] )
c = tf.add( a, b )
d = tf.add( c, b )

with tf.Session() as sess:

    tensors = sum( [ operation.outputs
                             for operation in sess.graph.get_operations() 
                             if operation.name.startswith( "Add") ],
                   [] )
    print( tensors )
    print( sess.run( tensors ) )

输出:

  

tf.Tensor'添加:0'形状=(1,)dtype = float32,tf.Tensor'Add_1:0'shape =(1,)dtype = float32
  [array([3。],dtype = float32),array([5。],dtype = float32)]