Cannot batch tensors with different shapes in component 0. First element had shape [224,224,3] and element 2 had shape [2843,2622,3]

Time: 2020-09-29 13:35:04

Tags: python tensorflow tensorflow2.x

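I want to adapt the SimCLR code to my own dataset. The dataset has a directory "train" with two subdirectories, "dog" and "cat". This is the data-loading method from the code: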

def read_batch(self, current_epoch, num_epochs):
        """Reads text file and processes data either for pretraining, classification or segmentation.
        
        This will read in the text file and depending on the task will return augmented images or if
        the task is classification or segmentation the class or image labels. A counter for the epoch
        is also returned. This increments after each epoch.
         
        Args:
            current_epoch: Epoch at which epoch counter begins
            num_epochs: number of epochs for epoch counter
        Returns:
            data: tuple containing current epoch and the augmented images and labels if required
        """
        
        textdata = tf.data.TextLineDataset(self.file_path)
        textdata = textdata.shuffle(5000, reshuffle_each_iteration=True)

        epoch_counter = tf.data.Dataset.range(current_epoch, num_epochs)
        if self.config.task == 'pretrain':
            data = textdata.map(lambda fnames: self.read(fnames))
            data = data.map(lambda image: self.augment(image))
            data = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip((data, tf.data.Dataset.from_tensors(i).repeat())))
        else:
            textdata = textdata.map(lambda line: tf.strings.split(line, ','))
            data_image = textdata.map(lambda fnames: self.read(fnames[0]))

            if self.config.task == 'classification' or self.config.task == 'regression':
                data_label = textdata.map(lambda classes: classes[1])
                data_label = data_label.map(lambda classes: tf.strings.to_number(classes))
            elif self.config.task == 'segmentation':
                data_label = textdata.map(lambda fnames: self.read(fnames[1]))            
        
            data = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip((data_image, data_label, tf.data.Dataset.from_tensors(i).repeat())))

        data = data.batch(self.batch_size, drop_remainder=True)

        return data

Here the input is a text file in which each line holds an image file name and its label, separated by a comma (",").
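For a folder layout like train/dog and train/cat, an index file in the format above could be generated with a short script. This is only a sketch under assumptions: the helper name `write_index`, the set of image extensions, and the numbering of classes in sorted name order (so cat → 0, dog → 1) are hypothetical choices, not part of SimCLR. The labels are written as numbers because `read_batch` converts the second field with `tf.strings.to_number`.

```python
import os
import tempfile

# Assumption: which file extensions count as images.
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def write_index(root_dir, out_path):
    """Hypothetical helper: write one 'image_path,label' line per image.

    Each subdirectory of root_dir is treated as one class, and classes are
    numbered in sorted name order (cat -> 0, dog -> 1). Returns that mapping.
    """
    class_names = sorted(
        d for d in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, d))
    )
    label_of = {name: idx for idx, name in enumerate(class_names)}
    with open(out_path, "w") as f:
        for name in class_names:
            class_dir = os.path.join(root_dir, name)
            for fname in sorted(os.listdir(class_dir)):
                if os.path.splitext(fname)[1].lower() not in IMAGE_EXTS:
                    continue  # skip non-image files
                path = os.path.join(class_dir, fname)
                if "," in path:
                    # read_batch splits each line on ',', so a comma in the
                    # path would push the label into the wrong field.
                    raise ValueError(f"comma in path: {path}")
                f.write(f"{path},{label_of[name]}\n")
    return label_of


if __name__ == "__main__":
    # Build a tiny fake train/ tree of empty files to show the output format.
    with tempfile.TemporaryDirectory() as tmp:
        root = os.path.join(tmp, "train")
        for cls, files in {"dog": ["d1.jpg"], "cat": ["c1.jpg", "c2.png"]}.items():
            os.makedirs(os.path.join(root, cls))
            for fn in files:
                open(os.path.join(root, cls, fn), "w").close()
        index_path = os.path.join(tmp, "train.txt")
        print(write_index(root, index_path))  # {'cat': 0, 'dog': 1}
        with open(index_path) as f:
            print(f.read())
```

The resulting file can then be passed as `self.file_path` to the original `TextLineDataset`-based `read_batch`, which already splits each line on "," and reads the image and label from fields 0 and 1.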

This is what I tried:

        textdata = tf.data.Dataset.from_tensor_slices(self.file_path)
        textdata = textdata.shuffle(5000, reshuffle_each_iteration=True)

        epoch_counter = tf.data.Dataset.range(current_epoch, num_epochs)
        if self.config.task == 'pretrain':
            data = textdata.map(lambda fnames: self.read(fnames))
            data = data.map(lambda image: self.augment(image))
            data = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip((data, tf.data.Dataset.from_tensors(i).repeat())))
        else:
            #textdata = textdata.map(lambda fnames: tf.strings.split(fnames, os.path.sep))
            #data_image = textdata.map(lambda fnames: self.read(fnames[0]))

            if self.config.task == 'classification' or self.config.task == 'regression':
                data_image = tf.data.Dataset.from_tensor_slices(self.file_path)
                data_image = data_image.map(lambda fnames: self.read(fnames))
         
                data_label = textdata.map(lambda fnames: self.process_path(fnames))

I get this error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot batch tensors with different shapes in component 0. First element had shape [224,224,3] and element 2 had shape [2843,2622,3].
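The message describes a constraint of `Dataset.batch`: within one component (component 0 is the image here), every element stacked into a batch must have the same shape, so an image decoded at its original size cannot be stacked with 224×224 images. The plain-Python toy model below only mirrors that rule and the wording of the message; it is not TensorFlow's implementation, and `first_shape_mismatch` is a hypothetical name. It also illustrates the presumable reading that the indices are 0-based positions in the batch, so "element 2" would be the third image. Under that reading, resizing every decoded image to one fixed size (for example with `tf.image.resize`) before `batch` would satisfy the rule.

```python
def first_shape_mismatch(shapes):
    """Toy model of the shape rule Dataset.batch enforces for one component.

    All elements must share the first element's shape. Returns the 0-based
    index and shape of the first element that differs, or None if every
    shape matches. Not TensorFlow's code; illustration only.
    """
    first = tuple(shapes[0])
    for i, shape in enumerate(shapes[1:], start=1):
        if tuple(shape) != first:
            return i, tuple(shape)
    return None


# Two images already at 224x224 and one still at its original size.
batch_shapes = [[224, 224, 3], [224, 224, 3], [2843, 2622, 3]]
print(first_shape_mismatch(batch_shapes))  # (2, (2843, 2622, 3))

# After resizing every image to the same size, the batch is consistent.
print(first_shape_mismatch([[224, 224, 3]] * 3))  # None
```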

0 answers