TensorFlow:从丢失函数中排除一个类的正确方法

时间:2016-10-10 10:30:04

标签: machine-learning tensorflow classification

我有三个类1(活动),0(非活动)和-1(未知)。我想在TensorFlow中构建一个模型,它根据输入预测活动或非活动。以下是仅在活动和非活动标签上计算损失并忽略未知标签的正确方法吗?

logits = tf.reshape(logits, [-1])
labels = tf.reshape(labels, [-1])
index = tf.where(tf.not_equal(labels, tf.constant(-1, dtype=tf.float32)))
logits = tf.gather(logits, index)
labels = tf.gather(labels, index)
entropies = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels)
loss = tf.reduce_mean(entropies)

0 个答案:

没有答案