你好Tensorflow初学者,
我想删除实现中的任何numpy代码,只使用tensorflow函数。目前我正试图以低置信度分数筛选出背景边界框和框。为此,我想要一个名为 keep 的索引,我可以用它来跟踪要保留的框:
# Filter out background boxes
keep = np.where(class_ids > 0)[0]
# Filter out low confidence boxes
if config.DETECTION_MIN_CONFIDENCE:
keep = np.intersect1d(
keep, np.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[0])
class_ids是一个形状的张量(1000,),其中每个条目是0到80之间的数字,具体取决于类(总共81个类)。
class_scores是一个形状的张量(1000,),其中每个条目是相应边界框的类的概率。
我知道np.where()很容易更改为tf.where但是如何使用tensorflow获得与np.intersect1d()相同的功能?
感谢您的帮助。
答案 0 :(得分:3)
这似乎复制了numpy.intersect1d示例。
import tensorflow as tf
a = tf.constant([3, 1, 2, 1])
b = tf.constant([1, 3, 4, 3])
# This set appears to be sorted, but that is not documented behavior.
s = tf.sets.set_intersection(a[None,:], b[None, :])
fsort = tf.contrib.framework.sort(s.values)
with tf.Session() as sess:
print(sess.run(s).values)
print(sess.run(fsort))
此输出
[1 3]
[1 3]
通过一些测试示例,set函数似乎给出了有序结果,但我无法验证它是否总会这样做。所以,你可能想要使用contrib函数来确定。