对于操作节点,具有控制依赖性的tf.identity的等效性

时间:2016-06-20 19:58:15

标签: python-2.7 tensorflow

我正在编写一个包装类,它带有一个带有特殊成员的通用图表" train_op"管理我的模型的培训,保存和管理。

我想干净地跟踪生命周期的训练步数,如下:

with tf.control_dependencies([ step_add_one ]):
    self.train_op=tf.identity(self.training_graph.train_op )
  

引发TypeError('预期的二进制或unicode字符串,得到%r'       e,is_training = True,输入=无)

我认为这里的问题是train_op是tf.Optimizer.minimize()的返回,所以它本身不是张量,而是一个操作。

一个明显的解决方法是在training_graph.loss上调用tf.identity,但是我失去了一点抽象,因为我必须在外部处理学习率等。而且,我觉得我错过了一些东西。

我怎样才能最好地解决这个问题?

1 个答案:

答案 0 :(得分:1)

您可以使用tf.group(),它可用于操作张量。

例如:

x = tf.Variable(1.)
loss = tf.square(x)
optimizer = tf.train.GradientDescentOptimizer(0.1)
train_op = optimizer.minimize(loss)

step = tf.Variable(0)
step_add_one = step.assign_add(1)

with tf.control_dependencies([step_add_one]):
    train_op_2 = tf.group(train_op)

现在,当您运行train_op_2时,step的值将会增加。

但是,最好的方法(如果您可以修改创建图表的图表)是向global_step函数添加参数minimize

train_op = optimizer.minimize(loss, global_step=step)