转换tf.data.dataset

时间:2020-03-16 18:51:13

标签: python tensorflow machine-learning

假设我有一个32 * 32 * 3类型图像的数据集作为源数据:

<DatasetV1Adapter shapes: {coarse_label: (), image: (32, 32, 3), label: ()}, types: {coarse_label: tf.int64, image: tf.uint8, label: tf.int64}>

序列化数据后,我得到:

<MapDataset shapes: {depth: (), height: (), image_raw: (), label: (), width: ()}, types: {depth: tf.int64, height: tf.int64, image_raw: tf.string, label: tf.int64, width: tf.int64}>

我可以使用这段代码访问每个元素:

for i in parsed_image_dataset.take(1):
  j=i['image_raw']
array_shape = e1['image'].numpy().shape
print(np.frombuffer(j.numpy(), dtype = 'uint8').reshape(array_shape))

其中e1是在原始数据集中使用get_next生成的。因此,如预期的那样,印刷品将打印出与一次预序列化相同的图像。但是,我可以以某种方式代替逐个元素地执行此操作立即将我的序列化数据集转换为原始uint8数据集?

1 个答案:

答案 0 :(得分:0)

您可以按照以下步骤在uint8中获取图像。

创建序列化数据。

list_ds = tf.data.Dataset.list_files("img_dir_path/*")

创建一个将file_path作为参数并以uint8格式返回图像的函数。

def process_img(file_path):
  img = tf.io.read_file(file_path)

  img = tf.image.decode_jpeg(img, channels=3)
  return img

使用map函数将上述函数应用于list_ds对象中的所有项目。

processed_images = list_ds.map(process_img)

processed_images 将包含给定图像目录的uint8格式的图像。