tensorflow:在变量中找到大数值并将其替换

时间:2018-08-11 12:45:45

标签: python tensorflow softmax

我有一个tf.Variable()来自softmax,它是一系列概率,例如[0.3, 0.5, 0.8, 0.1, 0.2]。我试图做的就是将该序列转换为[0,0,1,0,0],即最高概率​​替换为1,所有其他概率替换为0。但是由于tf.Variable()不是迭代的,而{{ 1}}仅给出最大的价值,我该怎么做?

1 个答案:

答案 0 :(得分:0)

import tensorflow as tf

tf.enable_eager_execution()
softmax = tf.constant([0.3, 0.5, 0.8, 0.1, 0.2], dtype=tf.float32)
index = tf.argmax(softmax, axis=0, output_type=tf.int32)

# sparse = tf.SparseTensor(tf.reshape(index, [-1]), tf.constant([[1]], dtype=tf.int32), tf.shape(softmax))
result = tf.scatter_nd(tf.expand_dims(tf.expand_dims(index, axis=0), axis=0), tf.constant([1], dtype=tf.int32),
                       shape=tf.shape(softmax))