TensorFlow:一种有效的方法来获取张量中不为零的最小元素的索引吗?

时间:2019-05-19 09:40:21

标签: python tensorflow boolean tensor

我使用TensorFlow 1.12。我有一个一维张量tag_mask_sizes,该张量主要包含零,但也包含一些正整数。我如何有效地获取不为零的最小元素的索引?我尝试了以下方法:

tag_mask_sizes_suppressed = tf.map_fn(lambda x: x if tf.not_equal(x, tf.constant(0, dtype=tf.uint8)) else 9999999, tag_mask_sizes)
        smallest_mask_index = tf.argmin(tag_mask_sizes_suppressed)

但是,tf.not_equal()产生了一个布尔张量,我无法在lambda的if-else条件下有效地对其进行评估。还有其他类似的优雅解决方案吗?

虽然我通常会急切地执行,但是这个问题发生在我在tf.Dataset.map()中使用的函数中,而这个函数并没有急切地执行。

1 个答案:

答案 0 :(得分:1)

实际上,您的代码等效于以下代码。

tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)

向量化方法将比tf.map_fn()快得多。另外,还有一些矢量化方法来获取一维张量中最小元素的索引,该索引不为零。一个例子:

import tensorflow as tf
# tf.enable_eager_execution()

tag_mask_sizes = tf.constant([2,0,1,3,1,32,0,0,0], dtype=tf.int32)

# approach 1, the disadvantage is that the maximum must be specified and only the first minimum can be found.
tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)

# approach 2, only the first minimum can be found.
tag_mask_sizes_nozeroidx = tf.where(tf.not_equal(tag_mask_sizes, 0))
tag_mask_sizes_suppressed = tf.gather_nd(tag_mask_sizes,tag_mask_sizes_nozeroidx)
smallest_mask_index2 = tag_mask_sizes_nozeroidx[tf.argmin(tag_mask_sizes_suppressed)]

# approach 3, find all minimum
tag_mask_sizes_suppressed = tf.boolean_mask(tag_mask_sizes,tf.not_equal(tag_mask_sizes, 0))
smallest_mask_index3 = tf.squeeze(tf.where(tf.equal(tag_mask_sizes,tf.reduce_min(tag_mask_sizes_suppressed))))

with tf.Session() as sess:
    print(sess.run(smallest_mask_index1))
    print(sess.run(smallest_mask_index2))
    print(sess.run(smallest_mask_index3))

# print
2
[2]
[2 4]