从图像本地目录创建tensorflow数据集

时间:2019-02-08 10:26:25

标签: python tensorflow dataset

我在本地拥有一个非常庞大的图像数据库,其数据分布就像每个文件夹都包含一类图像。

我想使用tensorflow数据集API来获取批处理数据,而无需将所有图像加载到内存中。

我尝试过这样的事情:

def _parse_function(filename, label):
    image_string = tf.read_file(filename, "file_reader")
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.cast(image_decoded, tf.float32)
    return image, label

image_list, label_list, label_map_dict = read_data()

dataset = tf.data.Dataset.from_tensor_slices((tf.constant(image_list), tf.constant(label_list)))
dataset = dataset.shuffle(len(image_list))
dataset = dataset.repeat(epochs).batch(batch_size)

dataset = dataset.map(_parse_function)

iterator = dataset.make_one_shot_iterator()

image_list是一个列表,其中已附加图像的路径(和名称),label_list是一个列表,其中已按相同顺序附加了每个图像的类。

但是_parse_function不起作用,我收到的错误是:

ValueError:形状必须为0,但输入形状为[?]的'file_reader'(op:'ReadFile')则为1。

我已经搜索了错误,但对我没有任何帮助。

如果我不使用map函数,我只是获取图像的路径(存储在image_list中),因此我认为我需要map函数来读取图像,但是我无法做到这一点有效。

谢谢。

编辑:

    def read_data():
        image_list = []
        label_list = []
        label_map_dict = {}
        count_label = 0

        for class_name in os.listdir(base_path):
            class_path = os.path.join(base_path, class_name)
            label_map_dict[class_name]=count_label

            for image_name in os.listdir(class_path):
                image_path = os.path.join(class_path, image_name)

                label_list.append(count_label)
                image_list.append(image_path)

            count_label += 1

1 个答案:

答案 0 :(得分:0)

错误在此行dataset = dataset.repeat(epochs).batch(batch_size)中,您的管道将batchsize添加为要输入的维。

您需要在像这样的地图函数之后对数据集进行批处理

    dataset = tf.data.Dataset.from_tensor_slices((tf.constant(image_list), tf.constant(label_list)))
    dataset = dataset.shuffle(len(image_list))
    dataset = dataset.repeat(epochs)
    dataset = dataset.map(_parse_function).batch(batch_size)