在Tensorflow中,tf.argmax()返回数组中最大元素的索引。
但是,对于多标签分类任务,返回数组中N个最大元素的函数非常方便。
predicted_array: [0.4, 0.6, 0.7, 0.2, 0.9]
tf.something(predicted_array, N = 2): [2,4]
然后将它与一个热编码数组的基础事实进行比较
one_hot_array: [0, 0, 1, 0, 1]
tf.something(one_hot_array, N = 2): [2,4]
有这样的功能吗?或类似的东西?
感谢您的帮助