我想以不同的方式计算DNN中的损失函数,具体取决于标签的值。
概念上它是这样的:
def loss(logits, labels):
if labels[0] == 0:
return loss_function_1(logits, labels)
else:
return loss_function_2(logits, labels)
显然这不起作用,因为我不能对张量对象进行这种比较。我也无法使用eval()
,因为我收到网络未定义的错误。我还有其他选择吗?
答案 0 :(得分:1)
您可以使用tf.cond
结构:
tf.cond(labels[0] == 0, lambda: loss_function_1(logits, labels),
lambda: loss_function_2(logits, labels))