tensorflow-如何通过数据集API输入目录

时间:2018-10-01 08:24:04

标签: python tensorflow

我是tensorflow的新手,这是我的情况:我有很多文件夹,每个文件夹包含多个图像。我需要训练输入为文件夹(每个2个文件夹),并且每次选择一个文件夹中的4张图像进行训练。 我尝试使用数据集api,并尝试使用mapflat_map函数,但无法读取文件夹中的图像。 这是我的代码的一部分:

def parse_function(filename):
    print(filename)
    batch_data = []
    batch_label = []
    dir_path = os.path.join(data_path, str(filename))
    imgs_list = os.listdir(dir_path)
    random.shuffle(imgs_list)
    imgs_list = imgs_list * 4 #each time select 4 images
    for i in range(img_num):
        img_path = os.path.join(dir_path, imgs_list[i])
        image_string = tf.read_file(img_path)  
        image_decoded = tf.image.decode_image(image_string)  
        image_resized = tf.image.resize_images(image_decoded, [224, 224])
        batch_data.append(image_resized)
        batch_label.append(label)
    return batch_data, batch_label
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) 
dataset = dataset.map(_parse_function) 

其中filename是文件夹名称的列表,例如“ 123456”,labels是标签的列表,例如0或1。

0 个答案:

没有答案