我是tensorflow的新手,这是我的情况:我有很多文件夹,每个文件夹包含多个图像。我需要训练输入为文件夹(每个2个文件夹),并且每次选择一个文件夹中的4张图像进行训练。
我尝试使用数据集api,并尝试使用map
或flat_map
函数,但无法读取文件夹中的图像。
这是我的代码的一部分:
def parse_function(filename):
print(filename)
batch_data = []
batch_label = []
dir_path = os.path.join(data_path, str(filename))
imgs_list = os.listdir(dir_path)
random.shuffle(imgs_list)
imgs_list = imgs_list * 4 #each time select 4 images
for i in range(img_num):
img_path = os.path.join(dir_path, imgs_list[i])
image_string = tf.read_file(img_path)
image_decoded = tf.image.decode_image(image_string)
image_resized = tf.image.resize_images(image_decoded, [224, 224])
batch_data.append(image_resized)
batch_label.append(label)
return batch_data, batch_label
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
其中filename是文件夹名称的列表,例如“ 123456”,labels是标签的列表,例如0或1。