tf.argmax()用于多个索引Tensorflow

时间:2018-05-19 10:05:52

标签: python tensorflow multilabel-classification

在Tensorflow中,tf.argmax()返回数组中最大元素的索引。

但是,对于多标签分类任务,返回数组中N个最大元素的函数非常方便。

predicted_array: [0.4, 0.6, 0.7, 0.2, 0.9]
tf.something(predicted_array, N = 2): [2,4]

然后将它与一个热编码数组的基础事实进行比较

one_hot_array: [0, 0, 1, 0, 1]
tf.something(one_hot_array, N = 2): [2,4]

有这样的功能吗?或类似的东西?

感谢您的帮助

1 个答案:

答案 0 :(得分:2)

是的,有。它是tf.nn.top_k(来自here)。

您可以将其用作tf.nn.top_k(predicted_array, k=2)