tf.data.experimental.CsvDataset produces training data with the wrong shape [dimension, batch_size]

Asked: 2019-06-20 02:58:50

Tags: python tensorflow tensorflow-datasets

I want to read CSV data with TensorFlow (1.13.1). Each sample has 429 feature dimensions, and the code is:

import tensorflow as tf


class VoiceDataGenerator(DataGenerator):
    def __init__(self):
        pass

    @staticmethod
    def parse_data(x, n_classes):
        return x[:-1], tf.one_hot(indices=tf.cast(x[-1], tf.int32), depth=n_classes)

    @staticmethod
    def load_dataset(batch_size, cpu_cores, dataset_path):
        dataset_train = tf.data.experimental.CsvDataset(dataset_path + 'train.csv', [tf.float32] * 430, header=False,
                                                        field_delim=' ')

        dataset_val = tf.data.experimental.CsvDataset(dataset_path + 'test.csv', [tf.float32] * 430, header=False,
                                                      field_delim=' ')

        n_sample_train = 1019915
        total_batches_train = n_sample_train // batch_size + 1
        n_sample_val = 57909

        n_classes = 1928

        dataset_train = dataset_train.shuffle(buffer_size=100000)
        dataset_train = dataset_train.map(map_func=lambda *x: VoiceDataGenerator.parse_data(x, n_classes),
                                          num_parallel_calls=cpu_cores)

        dataset_train = dataset_train.batch(batch_size)
        dataset_train = dataset_train.prefetch(buffer_size=1)

        dataset_val = dataset_val.map(map_func=lambda *x: VoiceDataGenerator.parse_data(x, n_classes),
                                      num_parallel_calls=cpu_cores)

        dataset_val = dataset_val.batch(batch_size)
        dataset_val = dataset_val.prefetch(buffer_size=1)

        return dataset_train, dataset_val, total_batches_train, n_sample_train, n_sample_val

Its superclass is:

class DataGenerator:
    @staticmethod
    def dataset_iterator(dataset_train, dataset_val):
        vgg_iter = tf.data.Iterator.from_structure(dataset_train.output_types, dataset_train.output_shapes)
        x, y = vgg_iter.get_next()

        # initializers that point the shared iterator at the training or validation data
        train_init = vgg_iter.make_initializer(dataset_train)
        test_init = vgg_iter.make_initializer(dataset_val)

        return train_init, test_init, x, y

The classes above are called as follows (this method lives inside another class, and the parameters are passed in as attributes):

    def load_dataset(self, config):
        dataset_train, dataset_val, self.total_batches_train, self.n_samples_train, self.n_samples_val = VoiceDataGenerator.load_dataset(
            config['basic'].getint('batch_size'), config['basic'].getint('cpu_cores'), self.path)
        self.train_init, self.test_init, self.X, self.Y = VoiceDataGenerator.dataset_iterator(dataset_train, dataset_val)

However, the returned data x has shape [dimension, batch_size] instead of [batch_size, dimension]. I would like to know why this happens and how to fix it.
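A likely explanation, offered as a sketch rather than a confirmed diagnosis: CsvDataset yields each record as a tuple of scalar tensors, one per column, so `x[:-1]` inside `parse_data` is a tuple of 429 scalars, not a single vector. `batch()` then batches every component separately, which gives 429 tensors of shape [batch_size]. Converted to one tensor, that reads as [429, batch_size]. Stacking the feature columns into one vector before batching should give [batch_size, dimension]. The example below assumes a tiny made-up CSV with 3 features and 4 classes in place of the 429 features and 1928 classes, and it iterates eagerly as in TensorFlow 2.x. Under 1.13.1 the `tf.stack` change in `parse_data` would be the same, with the iterator and initializers from the question kept as they are.

```python
import os
import tempfile

import tensorflow as tf

N_FEATURES = 3  # stand-in for the 429 features in the question
N_CLASSES = 4   # stand-in for the 1928 classes in the question


def parse_data(x, n_classes):
    # x is a tuple of scalar tensors, one per CSV column.
    # Stack the feature columns into a single [n_features] vector so that
    # batch() later produces [batch_size, n_features].
    features = tf.stack(x[:-1], axis=0)
    label = tf.one_hot(indices=tf.cast(x[-1], tf.int32), depth=n_classes)
    return features, label


# Write a small space-delimited CSV (hypothetical data) to a temporary file.
csv_path = os.path.join(tempfile.mkdtemp(), "train.csv")
with open(csv_path, "w") as f:
    f.write("0.1 0.2 0.3 0\n"
            "0.4 0.5 0.6 1\n"
            "0.7 0.8 0.9 2\n"
            "1.0 1.1 1.2 3\n")

dataset = tf.data.experimental.CsvDataset(
    csv_path, [tf.float32] * (N_FEATURES + 1), header=False, field_delim=" ")
dataset = dataset.map(lambda *x: parse_data(x, N_CLASSES))
dataset = dataset.batch(2)

# Eager iteration (TensorFlow 2.x assumed here).
features, labels = next(iter(dataset))
print(features.shape, labels.shape)
```

With this change the features of each batch should come out as [batch_size, dimension] and the labels as [batch_size, n_classes].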

0 answers:

No answers yet.