根据给定的阈值过滤多个图像预测

时间:2019-12-10 08:15:03

标签: tensorflow

尝试基于给定阈值过滤张量时遇到了一些麻烦。

我有一个包含两个图像的批处理:

  

[2,300,300,3]

我对此批次进行预测并收到以下值:

  

boxes [2,300,4]

     

得分[2,300]

我想在一个基于阈值的阈值上过滤分数,比如说0.10,我该如何过滤分数然后过滤相应的框?

所以输出看起来像:

  

output [2,50,4](当过滤后还剩下50个盒子时)

谢谢。

2 个答案:

答案 0 :(得分:1)

如果您需要TF解决方案,则应该可以进行以下操作。

import tensorflow as tf
import numpy as np

mask = tf.math.greater(scores,0.1)
boxes_above_thresh = tf.boolean_mask(boxes, mask)
scores_above_thresh = tf.boolean_mask(scores, mask)
with tf.Session() as sess:
  res = sess.run([scores_above_thresh, boxes_above_thresh])

话虽如此,这将返回一个(number of boxes, 4)类型的数组。也就是说,它不会返回(2, 50, 4)数组,而是返回(100, 4)数组。

编辑:获取具有(2, x, 4)类型输出的张量

我认为您正在寻找tf.ragged_tensor。哪个适合您的目的。以下解决方案将导致tf.RaggedTensorValue

mask = tf.math.greater(scores,0.1)
boxes_above_thresh = tf.ragged.boolean_mask(boxes, mask)
scores_above_thresh = tf.ragged.boolean_mask(scores, mask)

with tf.Session() as sess:

  res1, res2 = sess.run([scores_above_thresh, boxes_above_thresh])

答案 1 :(得分:0)

我认为这可以解决问题:

import numpy as np
filtered_indices=np.where(scores>thresh)
filtered_scores=x[filtered_indices]
filtered_boxes=boxes[filtered_indices]