argmax的损失函数返回int

时间:2017-01-25 02:11:47

标签: tensorflow

我正在定义一个损失函数RMSE,如下所示:

model.pred= tf.argmax(model.props, 1)
model.actual = tf.argmax(model.y, 1)
model.RMSE = tf.sqrt(tf.reduce_mean(tf.square(tf.sub(model.predictedSteer, model.actualSteer))))

model.pred是int64但sqrt给出了编译错误,因为它需要浮点输入。当我使用tf.cast时,损失函数变得不可区分。我该如何解决这个问题?

1 个答案:

答案 0 :(得分:0)

argmax是不可微分的,因此涉及它的大多数函数也是不可微分的。在不了解问题的情况下,很难给出解决方案。如果是分类问题,请使用logits(pred)而不是argmax。