在Tensorflow中将除max之外的所有值归零

时间:2018-05-28 09:32:42

标签: python tensorflow

我有一个数组[0.3, 0.5, 0.79, 0.2, 0.11].

除了最大值之外,我想将所有值转换为零。因此得到的数组将是: [0, 0, 0.79, 0, 0]

在Tensorflow图表中执行此操作的最佳方法是什么?

3 个答案:

答案 0 :(得分:3)

如果要保留所有出现的最大值,可以使用

cond = tf.equal(a, tf.reduce_max(a))
a_max = tf.where(cond, a, tf.zeros_like(a))

如果您只想保留一次最大值,可以使用

argmax = tf.argmax(a)
a_max = tf.scatter_nd([[argmax]], [a[argmax]], tf.to_int64(tf.shape(a)))

然而根据the doc of tf.argmax

  

请注意,如果是关系,则无法保证返回值的标识

据我了解,保留的最大值可能不是第一个或最后一个 - 如果在同一个阵列上运行两次,则可能不一样。

答案 1 :(得分:2)

如果你想要tf.argmax()的行为,并且只想要一个,那么你可以这样做:

  

tf.sparse_to_dense(tf.argmax(a),tf.cast(tf.shape(a),dtype = tf.int64),   tf.reduce_max(a))的

a = tf.constant([0.3, 0.5, 0.79, 0.79, 0.11])

out = tf.sparse_to_dense(tf.argmax(a),tf.cast(tf.shape(a), dtype=tf.int64), tf.reduce_max(a))

<强>输出:

[0.   0.   0.79 0.   0.  ]

答案 2 :(得分:1)

使用max查找最大值:

my_list = [0.3, 0.5, 0.79, 0.2, 0.11]
m = max(my_list)

然后使用list-comprehension

[0 if i != m else i for i in my_list]

<强>输出:

[0, 0, 0.79, 0, 0]