我要构建的数据集如下:
selected_cls = 1
# s_img: shape: [50,256,256], the images to use
# s_lbl: shape: [50,256,256], the ground truth masks
# s_cls_idx: shape: [50,], binary class, e.g. [0, 1, 1, 0, ...]
dataset = tf.data.Dataset.from_tensor_slices((s_img, s_lbl, s_cls_idx))
.filter(lambda img, lbl, cls_idx: cls_idx == selected_cls))
.batch(4)
it = iter(dataset)
next = next(it)
for 1 in range(4):
selected_cls = np.random.randint(2)
img, lbl, cls_idx = next
print(selected_cls, cls_idx)
我想要的输出是:
1,1
0,0
0,0
1,1
实际输出是
1,1
0,1
0,1
1,1
似乎一旦我构建了dataset
,它就将selected_cls
固定为1,这不是我想要的。
相反,我想为每个批次手动指定selected_cls
,并且dataset
可以提供与selected_cls
相同标签的数据。
dataset.make_initializable_iterator()
构建数据集,并且可以在每个批次中获取指定类的数据。但是它必须在Session()
中初始化,这不是我在TF 2.0中想要的。tf.data.Dataset.from_generator(gen_series)
来构建数据集并在gen_series
内编写一个过滤器函数,它可以工作,但是我认为这不是一种优雅的方法。我仍然想使用tf.data.Dataset.from_tensor_slices(...).filter()
进行过滤。这非常接近Filter Dataset to get just images from specific class,但与我的不同。