在Tensorflow中,如何使用索引创建条件?

时间:2018-08-24 07:19:33

标签: python tensorflow

我想使用基于索引的条件创建张量,在我的情况下为最小值:

min_index = tf.argmin(tensor0)
new_tensor = tensor0*3 if index==min_index or tensor0*2

例如:

tensor0 = [4, 5, 1, 3]
armin(tensor0) = 2
condition = [False, False, True, False]  <-- How to make this one?
new_tensor = [4*2, 5*2, 1*3, 3*2]

我不能使用tf.equal,因为我的值是浮点数,而不是像这个简单示例一样的整数。

1 个答案:

答案 0 :(得分:0)

我想我明白了!我缺少的是“ sparse_to_dense”功能。希望它可以帮助某人:

import tensorflow as tf

tensor0 = tf.constant([4, 5, 1, 3], tf.float32)

argmin = tf.argmin(tensor0)

binary_mask = tf.cast(tf.sparse_to_dense(argmin, tensor0.shape, 1), tf.bool)

new_tensor = tf.where(binary_mask, tensor0*3, tensor0*2)