我创建了细分功能。 y
是mnist标签,y_
是softmax的预测结果,pen_less
和pen_more
是两个惩罚参数。
loss = tf.reduce_sum(tf.where(
tf.greater(tf.to_float(tf.argmax(y, 1)), tf.to_float(tf.argmax(y_, 1))),
tf.pow(pen_less, tf.to_float(tf.argmax(y, 1)) - tf.to_float(tf.argmax(y_, 1))),
tf.pow(pen_more, tf.to_float(tf.argmax(y, 1)) - tf.to_float(tf.argmax(y_, 1)))))
答案 0 :(得分:-1)
编辑:因此,如果传递所有三个参数,则tf.where
是可区分的。我认为您的问题出在argmaxes上:
import tensorflow as tf
x = tf.Variable([0, 1, 2])
tf.gradients(tf.argmax(x), x)
输出:
LookupError:未为操作“ ArgMax”(操作类型:ArgMax)定义渐变
如果您想要微分损失函数,则需要避免argmax运算或寻找一种获得合适伪梯度的明智方法。