通过名称分配张量的值

时间:2016-06-03 21:50:58

标签: python tensorflow

我正在使用以下代码在TensorFlow图中抓取张量:

names = [var.name for var in self.graph.get_collection('trainable_variables')]
tensor = graph.get_tensor_by_name(names[0])
sess.run(tensor)

如何设置tensor

的值

1 个答案:

答案 0 :(得分:2)

大多数TensorFlow张量(tf.Tensor个对象)都是不可变的,因此您不能简单地为它们分配值。但是,如果您将张量创建为tf.Variable,则可以通过调用Variable.assign()为其指定值。

您不必要地将tf.Variable对象(来自tf.trainable_variables()列表)转换为字符串名称。相反,您可以执行以下操作:

# Get the 0th trainable variable.
var = tf.trainable_variables()[0]

# Create an op to assign a new value.
assign_op = var.assign([[1.0, 2.0], [3.0, 4.0]])

# Actually run the assignment.
sess.run(assign_op)

然而,根据您的评论,您有多个图表(即self.graphgraph不同),因此我上面写的一般解决方案无效。在这种情况下,您有两个选择:

  1. 在其他图表中按名称获取变量( NB 这仅在填充了graph_2.get_collection('trainable_variables')时才有效;如果您使用了tf.import_graph_def(),它将无效构建图表):

    var_name = graph_1.get_collection('trainable_variables')[0].name
    
    var_in_g2 = [v for v in graph_2.get_collection('trainable_variables')
                 if v.name == var_name][0]
    
    assign_op = var_in_g2.assign([[1.0, 2.0], [3.0, 4.0]])
    sess.run(assign_op)
    
  2. 在其他图表中按名称获取张量,并使用tf.assign()

    var_name = graph_1.get_collection('trainable_variables')[0].name
    
    var_in_g2 = graph_2.get_tensor_by_name(var_name)
    
    assign_op = tf.assign(var_in_g2, [[1.0, 2.0], [3.0, 4.0]])
    sess.run(assign_op)