我正在使用以下代码在TensorFlow图中抓取张量:
names = [var.name for var in self.graph.get_collection('trainable_variables')]
tensor = graph.get_tensor_by_name(names[0])
sess.run(tensor)
如何设置tensor
?
答案 0 :(得分:2)
大多数TensorFlow张量(tf.Tensor
个对象)都是不可变的,因此您不能简单地为它们分配值。但是,如果您将张量创建为tf.Variable
,则可以通过调用Variable.assign()
为其指定值。
您不必要地将tf.Variable
对象(来自tf.trainable_variables()
列表)转换为字符串名称。相反,您可以执行以下操作:
# Get the 0th trainable variable.
var = tf.trainable_variables()[0]
# Create an op to assign a new value.
assign_op = var.assign([[1.0, 2.0], [3.0, 4.0]])
# Actually run the assignment.
sess.run(assign_op)
然而,根据您的评论,您有多个图表(即self.graph
和graph
不同),因此我上面写的一般解决方案无效。在这种情况下,您有两个选择:
在其他图表中按名称获取变量( NB 这仅在填充了graph_2.get_collection('trainable_variables')
时才有效;如果您使用了tf.import_graph_def()
,它将无效构建图表):
var_name = graph_1.get_collection('trainable_variables')[0].name
var_in_g2 = [v for v in graph_2.get_collection('trainable_variables')
if v.name == var_name][0]
assign_op = var_in_g2.assign([[1.0, 2.0], [3.0, 4.0]])
sess.run(assign_op)
在其他图表中按名称获取张量,并使用tf.assign()
:
var_name = graph_1.get_collection('trainable_variables')[0].name
var_in_g2 = graph_2.get_tensor_by_name(var_name)
assign_op = tf.assign(var_in_g2, [[1.0, 2.0], [3.0, 4.0]])
sess.run(assign_op)