Tensorflow:使用张量中的值作为参数

时间:2017-01-26 15:08:19

标签: tensorflow

我想以不同的方式计算DNN中的损失函数,具体取决于标签的值。

概念上它是这样的:

def loss(logits, labels):

    if labels[0] == 0:
        return loss_function_1(logits, labels)
    else:
        return loss_function_2(logits, labels)

显然这不起作用,因为我不能对张量对象进行这种比较。我也无法使用eval(),因为我收到网络未定义的错误。我还有其他选择吗?

1 个答案:

答案 0 :(得分:1)

您可以使用tf.cond结构:

tf.cond(labels[0] == 0, lambda: loss_function_1(logits, labels),
                        lambda: loss_function_2(logits, labels))