找到两个张量的交集。返回两个输入张量中的已排序的唯一值

时间:2018-01-17 14:23:18

标签: python numpy tensorflow

你好Tensorflow初学者,

我想删除实现中的任何numpy代码,只使用tensorflow函数。目前我正试图以低置信度分数筛选出背景边界框和框。为此,我想要一个名为 keep 的索引,我可以用它来跟踪要保留的框:

# Filter out background boxes
keep = np.where(class_ids > 0)[0]
# Filter out low confidence boxes
if config.DETECTION_MIN_CONFIDENCE:
    keep = np.intersect1d(
        keep, np.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[0])

class_ids是一个形状的张量(1000,),其中每个条目是0到80之间的数字,具体取决于类(总共81个类)。

class_scores是一个形状的张量(1000,),其中每个条目是相应边界框的类的概率。

我知道np.where()很容易更改为tf.where但是如何使用tensorflow获得与np.intersect1d()相同的功能?

感谢您的帮助。

1 个答案:

答案 0 :(得分:3)

这似乎复制了numpy.intersect1d示例。

import tensorflow as tf

a = tf.constant([3, 1, 2, 1])
b = tf.constant([1, 3, 4, 3])

# This set appears to be sorted, but that is not documented behavior.
s = tf.sets.set_intersection(a[None,:], b[None, :])
fsort = tf.contrib.framework.sort(s.values)

with tf.Session() as sess:
    print(sess.run(s).values)
    print(sess.run(fsort))

此输出

[1 3]
[1 3]

通过一些测试示例,set函数似乎给出了有序结果,但我无法验证它是否总会这样做。所以,你可能想要使用contrib函数来确定。