将自定义损失添加到eval_metric_ops

时间:2017-08-16 18:16:58

标签: python tensorflow google-cloud-ml-engine

我使用sequence_loss

定义了自己的损失函数
loss = tf.contrib.legacy_seq2seq.sequence_loss(logits, labels, weights)

我希望将此添加到eval_metric_ops,以便在我的ML引擎包中,我可以连续显示tensorboard中的评估损失(默认值只是准确性)。我尝试将其添加为自定义eval_metric_ops

eval_metric_ops = {
    'loss': loss # this has already been coputed for Modes.EVAL
}

但是,我收到错误" TypeError:eval_metric_ops的值必须是(metric_value,update_op)元组,给定:Tensor(" sequence_loss / truediv:0",shape =(),dtype = float32)for key:loss" - 如何将损失作为eval_metric_op传递,我需要做什么?我猜我当前的损失应该是指标值,但我不确定update_op应该是什么?

1 个答案:

答案 0 :(得分:3)

您的案例中的指标函数可以使用tf.metrics使用损失的移动平均值来实现:

def metric_fn(labels, predict):
   loss = tf.contrib.legacy_seq2seq.sequence_loss(logits, labels, weights)
   mean, op = tf.metrics.mean(loss)
   return mean, op