Filter a TensorFlow dataset by class

Time: 2020-08-08 18:14:55

Tags: python tensorflow tensorflow-datasets

I want to filter a TensorFlow dataset so that it only outputs values for a specific class/label. How would we do that for the code below?

Thanks

image_feature_description = {
    'label': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([100, 100, 3], tf.float32),
}

def parse_tfrecord(example_proto):
  features = tf.io.parse_example(example_proto, image_feature_description)
  label = features['label']
  image = features['image']
  return image

dataset = dataset.map(parse_tfrecord).batch(batch_size)

1 answer:

Answer 0 (score: 2)

Modify the parse_tfrecord function to return both the image and the label:

def parse_tfrecord(example_proto):
  ... # parsing an example
  return image, label

Then add a filter between the map and batch ops that keeps only the examples where label == MY_LABEL:

# Parentheses replace the backslashes: a comment after "\" is a syntax error.
dataset = (
    dataset.map(parse_tfrecord)
    .filter(lambda image, label: label == MY_LABEL)
    .map(lambda image, label: image)  # add this if you want images only
    .batch(...)
)
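
To try the filter/map/batch chain without a TFRecord file, here is a minimal sketch on an in-memory dataset. It assumes TensorFlow 2.x with eager execution. The toy images, the label values and MY_LABEL = "cat" are made up for illustration and are not from the question. It uses tf.equal rather than ==, which should behave the same way in TF 2.x but states the tensor comparison explicitly.

```python
import numpy as np
import tensorflow as tf

MY_LABEL = "cat"  # hypothetical target class, chosen for this example

# Toy stand-in for the parsed dataset: 6 tiny "images" with string labels.
images = np.arange(6 * 2 * 2 * 3, dtype=np.float32).reshape(6, 2, 2, 3)
labels = tf.constant(["cat", "dog", "cat", "bird", "dog", "cat"])

dataset = tf.data.Dataset.from_tensor_slices((images, labels))

filtered = (
    dataset
    # Keep only elements whose label matches; the predicate must return a scalar bool.
    .filter(lambda image, label: tf.equal(label, MY_LABEL))
    # Drop the label once it has been used for filtering.
    .map(lambda image, label: image)
    # Batch last, so the filter sees one example at a time.
    .batch(2)
)

batches = [batch.numpy() for batch in filtered]
batch_shapes = [b.shape for b in batches]
```

The filter is placed before batch on purpose: after batching, label would be a vector and the predicate would no longer return a single boolean per element.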