从CSV文件读取图像并返回tf.data.Dataset对象的有效方法

时间:2019-07-05 15:37:36

标签: python tensorflow

我有一个包含两列的csv文件:

  1. 存储为numpy数组的图像的文件路径
  2. 图片标签

csv中的每一行对应一个项目(样本)。

我想创建一个tf.data管道来读取文件路径并加载numpy数组和与其关联的标签。我该怎么做才能返回tf.data.Dataset对象?

网站上的documentation信息不多,我不知道从哪里开始。

2 个答案:

答案 0 :(得分:1)

一种方法是简单地将这两个文件加载到变量中并使用tf.data.Dataset.from_tensor_slices(请参见https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays

另一种方法是将文件路径映射到数据集中并进行数据流水线读取并以(img,label)返回 这是https://www.tensorflow.org/tutorials/load_data/images

中的示例代码
def load_and_preprocess_image(path):
  image = tf.read_file(path)
  return preprocess_image(image)

ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))

# The tuples are unpacked into the positional arguments of the mapped function
def load_and_preprocess_from_path_label(path, label):
  return load_and_preprocess_image(path), label

image_label_ds = ds.map(load_and_preprocess_from_path_label)

如果数据对于内存来说太大,我自己会选择第二种方法,但是对于小数据,第一种方法很方便

答案 1 :(得分:0)

本教程应该是一个不错的起点:https://www.tensorflow.org/tutorials/load_data/images

如链接中所述,加载图像路径及其标签。创建具有路径及其标签的from_tensor_slices的数据集,然后使用预处理功能将路径(字符串)映射为张量图像。

ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))

# The tuples are unpacked into the positional arguments of the mapped function
def load_and_preprocess_from_path_label(path, label):
  return load_and_preprocess_image(path), label

image_label_ds = ds.map(load_and_preprocess_from_path_label)
image_label_ds

请按照本教程的逐步说明进行操作。如果将图像保存为numpy数组而不是jpg文件,则必须更改一些预处理,但总体流程应该非常相似。