我想修改 SimCLR 的代码，使其能够使用我自己的数据集。该数据集包含一个目录“train”，其下有两个子目录：“dog”和“cat”。原始的 read_batch 方法如下：
def read_batch(self, current_epoch, num_epochs, image_size=None):
    """Build the batched tf.data pipeline for pretraining, classification, regression or segmentation.

    Reads the text file at ``self.file_path`` line by line, shuffles the lines
    (reshuffled on every iteration) and, depending on ``self.config.task``,
    produces:

    * ``'pretrain'``: each line is an image path; the image is loaded with
      ``self.read`` and passed through ``self.augment``.
    * ``'classification'`` / ``'regression'``: each line is
      ``<image path>,<label>``; the label is parsed with
      ``tf.strings.to_number``.
    * ``'segmentation'``: each line is ``<image path>,<label image path>``;
      both are loaded with ``self.read``.

    Every sample is paired with an epoch counter that runs from
    ``current_epoch`` up to, but not including, ``num_epochs``; the shuffled
    file is traversed once per counter value.

    Args:
        current_epoch: Epoch at which the epoch counter begins.
        num_epochs: Exclusive upper bound of the epoch counter.
        image_size: Optional ``(height, width)``. When given, every image (and
            segmentation label image) returned by ``self.read`` is resized to
            this size immediately after reading, so files of different
            resolutions can be batched together. Images use bilinear resizing
            and are cast back to their original dtype; segmentation labels use
            nearest-neighbour resizing so no new label values are invented.
            Requires ``self.read`` to return a rank-3 ``(H, W, C)`` tensor.
            When ``None`` (default) images are used exactly as ``self.read``
            returns them, which must then already share one shape.

    Returns:
        A ``tf.data.Dataset`` of batches of ``self.batch_size`` elements (the
        incomplete last batch is dropped). Elements are
        ``(augmented, epoch)`` for ``'pretrain'``, where ``augmented`` is
        whatever ``self.augment`` returns, and ``(image, label, epoch)`` for
        the other tasks.

    Raises:
        ValueError: If ``self.config.task`` is not one of the tasks above.
    """
    task = self.config.task
    if task not in ('pretrain', 'classification', 'regression', 'segmentation'):
        # Fail early with a clear message instead of an unbound-variable error
        # while building the zipped dataset.
        raise ValueError(f"Unsupported task: {task!r}")

    def _resize(image, method):
        # tf.image.resize returns float32 for every method except nearest
        # neighbour; cast back so callers see the dtype ``self.read`` produced.
        return tf.cast(tf.image.resize(image, image_size, method=method), image.dtype)

    def _load_image(path):
        # Read one image, optionally resizing it to ``image_size``.
        image = self.read(path)
        if image_size is not None:
            image = _resize(image, tf.image.ResizeMethod.BILINEAR)
        return image

    def _load_label_image(path):
        # Read one segmentation label image, optionally resizing it without
        # interpolating between label values.
        label = self.read(path)
        if image_size is not None:
            label = _resize(label, tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return label

    lines = tf.data.TextLineDataset(self.file_path)
    lines = lines.shuffle(5000, reshuffle_each_iteration=True)
    epoch_counter = tf.data.Dataset.range(current_epoch, num_epochs)

    if task == 'pretrain':
        images = lines.map(_load_image)
        augmented = images.map(lambda image: self.augment(image))
        data = epoch_counter.flat_map(
            lambda i: tf.data.Dataset.zip(
                (augmented, tf.data.Dataset.from_tensors(i).repeat())))
    else:
        fields = lines.map(lambda line: tf.strings.split(line, ','))

        if task == 'segmentation':
            def _parse(f):
                return _load_image(f[0]), _load_label_image(f[1])
        else:  # 'classification' or 'regression'
            def _parse(f):
                return _load_image(f[0]), tf.strings.to_number(f[1])

        # Build image and label from the SAME line in a single map. Zipping two
        # separately mapped copies of the shuffled dataset would give each copy
        # its own iterator and shuffle order, pairing images with labels from
        # other lines.
        pairs = fields.map(_parse)
        data = epoch_counter.flat_map(
            lambda i: pairs.map(lambda image, label: (image, label, i)))

    return data.batch(self.batch_size, drop_remainder=True)
在原始代码中，输入是一个文本文件：每行包含一个图像文件名及其对应的标签，两者之间用逗号“,”分隔。
我尝试做如下修改：
# Attempted rewrite of the body of read_batch (excerpt: the enclosing `def`
# line and the end of the method are not shown).
# NOTE(review): tf.data.Dataset.from_tensor_slices needs a list / 1-D tensor of
# image paths; a single directory string such as "train" (rank 0) is rejected.
# Presumably self.file_path was changed to a list of file paths — confirm.
textdata = tf.data.Dataset.from_tensor_slices(self.file_path)
textdata = textdata.shuffle(5000, reshuffle_each_iteration=True)
epoch_counter = tf.data.Dataset.range(current_epoch, num_epochs)
if self.config.task == 'pretrain':
data = textdata.map(lambda fnames: self.read(fnames))
data = data.map(lambda image: self.augment(image))
data = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip((data, tf.data.Dataset.from_tensors(i).repeat())))
else:
#textdata = textdata.map(lambda fnames: tf.strings.split(fnames, os.path.sep))
#data_image = textdata.map(lambda fnames: self.read(fnames[0]))
if self.config.task == 'classification' or self.config.task == 'regression':
# BUG(review): data_image is built from a fresh, UNshuffled slice of
# self.file_path, while data_label below comes from the shuffled `textdata`.
# Zipping the two pairs each image with another file's label; derive image and
# label from the same element in a single map instead.
data_image = tf.data.Dataset.from_tensor_slices(self.file_path)
# NOTE(review): the reported error shows images of different shapes
# ([224,224,3] vs [2843,2622,3]), so self.read presumably does not resize.
# .batch() cannot stack tensors of different shapes; resize every image to one
# common (H, W) before batching.
data_image = data_image.map(lambda fnames: self.read(fnames))
# NOTE(review): process_path is not shown here; if it returns (image, label)
# the image is loaded a second time in this dataset — confirm its return value.
data_label = textdata.map(lambda fnames: self.process_path(fnames))
# NOTE(review): the excerpt ends here — in this branch `data` is never built
# (zip with epoch_counter), batched, or returned.
运行后出现以下错误：
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot batch tensors with different shapes in component 0. First element had shape [224,224,3] and element 2 had shape [2843,2622,3].