我正在使用TensorFlow读取TFRecord文件,其中存储了两个2D数组和一个带有浮点值的1D数组。数据存储在几个TFRecord文件中,这些文件代表了分成几个文件的巨大数据集。要读取数据,我正在使用 dataset.map()函数获取具有先前保存的要素形状的已解析要素。
准备TFRecord文件时,我将这些功能另存为浮动列表并展平。另外,我还保存了每个文件的形状,以便以后在读取这些文件时可以正确地读取它。
当我遍历数据而不批量处理数据时,我得到了正常的特征数组。例如:如果我的X(输入)值的要素形状为(10,342),那么我将获得10组342个值。但是,如果我将数据库批处理设置为2,则突然我的要素形状将变为(1、10、342)。就像TensorFlow批处理不能识别出10组数据,而是将其整体处理一样。
我读取TFRecord文件的代码如下:
x_shape = (10, 342)
y_shape = (10, 311)
p_shape = (10,)
def _parse_function(example_proto):
keys_to_features = {'X':tf.FixedLenFeature(x_shape, tf.float32),
'Y':tf.FixedLenFeature(y_shape, tf.float32),
'P':tf.FixedLenFeature(p_shape, tf.float32)}
parsed_features = tf.parse_single_example(example_proto, keys_to_features)
return parsed_features['X'], parsed_features['Y'], parsed_features['P']
def _load_tfrecord():
training_filenames = [TF_DATABASE_PATH + TF_DATABASE_NAME]
dataset = tf.data.TFRecordDataset(training_filenames)
dataset = dataset.map(_parse_function)
dataset = dataset.batch(2)
init = tf.global_variables_initializer()
iterator = dataset.make_initializable_iterator()
nextElement = iterator.get_next()
with tf.Session() as sess:
sess.run(init)
sess.run(iterator.initializer)
currentBatch = sess.run(nextElement)
所以我的问题是如何使TensorFlow将数据分为正确的批次?在这种情况下,如果我有10组数据(每个X具有342个元素),那么我希望每组有5批2组。
这是我第一次使用TFRecord,非常感谢您的帮助!谢谢。