我想使用tensorflow(1.13.1)读取csv数据。数据的维数为429,代码为:
class VoiceDataGenerator(DataGenerator):
def __init__(self):
pass
@staticmethod
def parse_data(x, n_classes):
return x[:-1], tf.one_hot(indices=tf.cast(x[-1], tf.int32), depth=n_classes)
@staticmethod
def load_dataset(batch_size, cpu_cores, dataset_path):
dataset_train = tf.data.experimental.CsvDataset(dataset_path + 'train.csv', [tf.float32] * 430, header=False,
field_delim=' ')
dataset_val = tf.data.experimental.CsvDataset(dataset_path + 'test.csv', [tf.float32] * 430, header=False,
field_delim=' ')
n_sample_train = 1019915
total_batches_train = n_sample_train // batch_size + 1
n_sample_val = 57909
n_classes = 1928
dataset_train = dataset_train.shuffle(buffer_size=100000)
dataset_train = dataset_train.map(map_func=lambda *x: VoiceDataGenerator.parse_data(x, n_classes),
num_parallel_calls=cpu_cores)
dataset_train = dataset_train.batch(batch_size)
dataset_train = dataset_train.prefetch(buffer_size=1)
dataset_val = dataset_val.map(map_func=lambda *x: VoiceDataGenerator.parse_data(x, n_classes),
num_parallel_calls=cpu_cores)
dataset_val = dataset_val.batch(batch_size)
dataset_val = dataset_val.prefetch(buffer_size=1)
return dataset_train, dataset_val, total_batches_train, n_sample_train, n_sample_val
其超类是:
class DataGenerator:
@staticmethod
def dataset_iterator(dataset_train, dataset_val):
vgg_iter = tf.data.Iterator.from_structure(dataset_train.output_types, dataset_train.output_shapes)
x, y = vgg_iter.get_next()
# initializer for train_data
train_init = vgg_iter.make_initializer(dataset_train)
test_init = vgg_iter.make_initializer(dataset_val)
return train_init, test_init, x, y
上述类由by调用(位于一个类中,并且params由属性传输)
def load_dataset(self, config):
dataset_train, dataset_val, self.total_batches_train, self.n_samples_train, self.n_samples_val = VoiceDataGenerator.load_dataset(
config['basic'].getint('batch_size'), config['basic'].getint('cpu_cores'), self.path)
self.train_init, self.test_init, self.X, self.Y = VoiceDataGenerator.dataset_iterator(dataset_train, dataset_val)
,但是返回的数据x
的形状为[dimension,batch_size]而不是[batch_size,Dimensions]。我想知道为什么以及如何处理它。