如何检查张量的argmax是否等于另一个张量的argmax,它具有几个相等的最大值?

时间:2018-03-30 23:10:10

标签: tensorflow indices tensor argmax

通常在单一标签分类中,我们使用以下

correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(self.label, 1))

但我正在使用多标签分类,所以我想知道如何在标签向量中有几个这样做。所以到目前为止我所拥有的是

 a = tf.constant([0.2, 0.4, 0.3, 0.1])
 b = tf.constant([0,1.0,1,0])
 empty_tensor = tf.zeros([0])
 for index in range(b.get_shape()[0]):
     empty_tensor = tf.cond(tf.equal(b[index],tf.constant(1, dtype = 
     tf.float32)), lambda:  tf.concat([empty_tensor,tf.constant([index], 
     dtype = tf.float32)], axis = 0), lambda: empty_tensor)

 temp, _ = tf.setdiff1d([tf.argmax(a)], tf.cast(empty_tensor, dtype= tf.int64))
 output, _ = tf.setdiff1d([tf.argmax(a)], tf.cast(temp, dtype = tf.int64))

所以这给了我max(preds)发生的指标以及self.label中的1。在上面的例子中,它给出[1],如果argmax不匹配,那么我得到[]。

我遇到的问题是我不知道如何从那里开始,因为我想要像下面这样的东西

correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(self.label, 1))
self.accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))

这对于单一标签分类很简单。

非常感谢

1 个答案:

答案 0 :(得分:1)

我认为你不能用softmax实现这个目标,所以我假设你正在使用sigmoids来实现你的预测。如果您使用的是sigmoids,则输出将是(独立地)介于0和1之间。您可以为每个输出定义阈值,可能为0.5,然后将sigmoid preds转换为label编码(通过preds > 0.5来完成0和1。

如果预测为[0 1]且标签为[1 1],您要报告完全或部分错误吗?我将承担前者。在这种情况下,您会删除tf.argmax调用,而是检查predslabel是否完全相同,这看起来像tf.reduce_all(tf.equal(preds, label), axis=0)。对于后者,代码看起来像tf.reduce_sum(tf.equal(preds, label), axis=0)