我在MSCOCO数据集上使用tensorflow,它有80个不同的类,每个图像可以包含多个类(多标签),这在我的输入中建模为每个图像的80个元素的列表,在右边包含1个他错了班级的班级和0。
我正在尝试通过检查预测的logits图层中的1的数量并将其元素与输入的真实标签进行比较来计算准确度。
使用tf.equal(logits , truelabels)
不适用,因为有80个类具有正确预测的大量零。所以我最终得到了高精度数和错误的预测。
我尝试使用tf.map_fn
将张量分解为元素,但显然我不能在函数内部使用类似变量的计数器来从真正的标签批处理中选择当前的真实标签({{1} })。
这是我的代码:
Y_