Tensorflow:在丢失操作时使用tf.cond的问题

时间:2017-06-13 10:39:04

标签: python machine-learning tensorflow control-flow

我试图在训练阶段使用加权损失,但不是在验证阶段。我的想法是使用tf.cond和占位符is_training,我将使用feed_dict(真或假)提供

loss = tf.cond(is_training,
               lambda: tf.losses.mean_squared_error(labels, predictions, scope="loss", weights=loss_weights),
               lambda: tf.losses.mean_squared_error(labels, predictions, scope="loss"))

当我单独测试错误时它会起作用,但是一旦我使用tf.cond我得到:

tensorflow.python.framework.errors_impl.InvalidArgumentError: The tensor returned for train_op:0 was not valid.

loss操作在训练期间使用如下(在验证期间,它只是由sess.run()调用): 将正则化损失添加到tf.collection并添加所有损失。然后:

gradients = optimizer.compute_gradients(total_loss, variables_to_train)
grad_updates = optimizer.apply_gradients(gradients,global_step=global_step)

我该如何使用tensorflow控制流程?

修改:tensorflow-gpu 1.0.1

0 个答案:

没有答案