I want to filter a TensorFlow dataset so that it only yields examples of a specific class/label. How would I do that for the code below?
Thanks
image_feature_description = {
    'label': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([100, 100, 3], tf.float32),
}

def parse_tfrecord(example_proto):
    features = tf.io.parse_example(example_proto, image_feature_description)
    label = features['label']
    image = features['image']
    return image

dataset = dataset.map(parse_tfrecord).batch(batch_size)
Answer 0 (score: 2)
Modify the parse_tfrecord function so that it returns both the image and the label:
def parse_tfrecord(example_proto):
    ...  # parsing an example
    return image, label
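For reference, here is a minimal sketch of what the full parsing function might look like, reusing the feature description from the question. Two things are assumptions rather than part of the original post: it uses `tf.io.parse_single_example` (which parses one serialized record at a time, as happens when mapping before batching) instead of the question's `tf.io.parse_example`, and the `tf.train.Example` built at the end is a made-up record used only to try the function out.

```python
import tensorflow as tf

# Feature description copied from the question.
image_feature_description = {
    'label': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([100, 100, 3], tf.float32),
}

def parse_tfrecord(example_proto):
    # Assumption: one serialized example per call, so parse_single_example is used.
    features = tf.io.parse_single_example(example_proto, image_feature_description)
    # Return the label alongside the image so a later filter can inspect it.
    return features['image'], features['label']

# Hypothetical record, for illustration only: an all-zero image labelled b'cat'.
example = tf.train.Example(features=tf.train.Features(feature={
    'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'cat'])),
    'image': tf.train.Feature(
        float_list=tf.train.FloatList(value=[0.0] * (100 * 100 * 3))),
}))

image, label = parse_tfrecord(example.SerializeToString())
print(image.shape, label.numpy())
```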
Then add a filter between the map and batch ops that keeps only the elements where label == MY_LABEL:
dataset = (
    dataset.map(parse_tfrecord)
    .filter(lambda image, label: label == MY_LABEL)
    .map(lambda image, label: image)  # add this if you want the image only
    .batch(...)
)
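As a rough, self-contained check of the filter-then-map pattern above, the sketch below swaps the TFRecord input for a small in-memory dataset. The label values, the tiny image shape and the batch size are all made up for illustration, and `tf.equal` is used in place of `==` only to make the tensor comparison explicit.

```python
import tensorflow as tf

# Hypothetical target label; tf.string tensors hold bytes.
MY_LABEL = b"cat"

# Toy data: four 2x2 RGB "images" with one label each.
images = tf.reshape(tf.range(4 * 2 * 2 * 3, dtype=tf.float32), [4, 2, 2, 3])
labels = tf.constant([b"cat", b"dog", b"cat", b"bird"])

dataset = tf.data.Dataset.from_tensor_slices((images, labels))

filtered = (
    dataset
    # Keep only the elements whose label matches MY_LABEL.
    .filter(lambda image, label: tf.equal(label, MY_LABEL))
    # Drop the label once filtering is done.
    .map(lambda image, label: image)
    .batch(2)
)

batches = list(filtered.as_numpy_iterator())
num_kept = sum(batch.shape[0] for batch in batches)
print(num_kept)
```

Filtering has to happen before batching: the predicate passed to `filter` must return a scalar boolean, which is only the case while each element is still a single (image, label) pair.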