我想从tf.nn.softmax()
的输出中获取最大(2个或更多)索引设置为1。
给定tf.nn.softmax
的输出为[0.1, 0.4, 0.2, 0.1, 0.8]
,我想得到类似[0,1,0,0,1]
的东西,因为这些索引具有最大数目(在这种情况下,我只选择了最大2)。预先谢谢你!
答案 0 :(得分:0)
tf.nn.softmax
强制将所有内容加起来1.0
进行有效的概率分布。如果要将向量中的多个值设为1,则应改用tf.nn.sigmoid
。
如果要检索向量中的最大值,请使用tf.nn.top_k
。
答案 1 :(得分:0)
您可以使用tf.nn.top_k
来返回输入向量的最大值及其位置。
probs = tf.nm.softmax(logits)
k = 2 # the first k=2 highest values
indices, values = tf.nn.top_k(probs, k=k)