我正在定义一个损失函数RMSE,如下所示:
model.pred= tf.argmax(model.props, 1)
model.actual = tf.argmax(model.y, 1)
model.RMSE = tf.sqrt(tf.reduce_mean(tf.square(tf.sub(model.predictedSteer, model.actualSteer))))
model.pred是int64但sqrt给出了编译错误,因为它需要浮点输入。当我使用tf.cast时,损失函数变得不可区分。我该如何解决这个问题?
答案 0 :(得分:0)
argmax
是不可微分的,因此涉及它的大多数函数也是不可微分的。在不了解问题的情况下,很难给出解决方案。如果是分类问题,请使用logits(pred
)而不是argmax。