现在我正在执行以下操作来访问图表中的张量
graph = tf.get_default_graph()
add_0 = graph.get_tensor_by_name("Add:0")
add_1 = graph.get_tensor_by_name("Add_1:0")
add_2 = graph.get_tensor_by_name("Add_2:0")
当图表很短时,这种方法是可以的。但对于较长的图表,它变得非常无聊。
有没有办法以干净的方式收集以Add
开头的所有张量?类似的东西:
add = []
for Add in graph.get_tensors_by_name():
add.append(Add)
(我知道这个伪代码确实是错误的)
这样我得到add = [add_0, add_1, add_2, ... ]
后来我想用它来做到这一点:sess.run(add, feed_dict={input: data})
答案 0 :(得分:2)
您可以使用sess.graph.get_operations()
获取所有张量,然后使用startswith()
选择您需要的张量。经过测试的代码:
import tensorflow as tf
a = tf.constant( [ 1.0 ] )
b = tf.constant( [ 2.0 ] )
c = tf.add( a, b )
d = tf.add( c, b )
with tf.Session() as sess:
tensors = sum( [ operation.outputs
for operation in sess.graph.get_operations()
if operation.name.startswith( "Add") ],
[] )
print( tensors )
print( sess.run( tensors ) )
输出:
tf.Tensor'添加:0'形状=(1,)dtype = float32,tf.Tensor'Add_1:0'shape =(1,)dtype = float32
[array([3。],dtype = float32),array([5。],dtype = float32)]