在TensorFlow 2.0中按特定类别过滤数据集

时间:2019-12-03 10:22:07

标签: python tensorflow

问题

我要构建的数据集如下:

selected_cls = 1

# s_img: shape: [50,256,256], the images to use
# s_lbl: shape: [50,256,256], the ground truth masks
# s_cls_idx: shape: [50,], binary class, e.g. [0, 1, 1, 0, ...]

dataset = tf.data.Dataset.from_tensor_slices((s_img, s_lbl, s_cls_idx))
                         .filter(lambda img, lbl, cls_idx: cls_idx == selected_cls))
                         .batch(4)
it = iter(dataset)
next = next(it) 

for 1 in range(4):
  selected_cls = np.random.randint(2)
  img, lbl, cls_idx = next
  print(selected_cls, cls_idx)

我想要的输出是:

1,1
0,0
0,0
1,1

实际输出是

1,1
0,1
0,1
1,1

似乎一旦我构建了dataset,它就将selected_cls固定为1,这不是我想要的。 相反,我想为每个批次手动指定selected_cls,并且dataset可以提供与selected_cls相同标签的数据。

尝试

  1. 在TensorFlow 1.12中,我使用dataset.make_initializable_iterator()构建数据集,并且可以在每个批次中获取指定类的数据。但是它必须在Session()中初始化,这不是我在TF 2.0中想要的。
  2. 如果我使用tf.data.Dataset.from_generator(gen_series)来构建数据集并在gen_series内编写一个过滤器函数,它可以工作,但是我认为这不是一种优雅的方法。我仍然想使用tf.data.Dataset.from_tensor_slices(...).filter()进行过滤。

相关问题

这非常接近Filter Dataset to get just images from specific class,但与我的不同。

0 个答案:

没有答案