TensorFlow-读取TFRecord多维数组并将其批处理

时间:2019-02-19 22:47:02

标签: python tensorflow multidimensional-array tfrecord batching

我正在使用TensorFlow读取TFRecord文件,其中存储了两个2D数组和一个带有浮点值的1D数组。数据存储在几个TFRecord文件中,这些文件代表了分成几个文件的巨大数据集。要读取数据,我正在使用 dataset.map()函数获取具有先前保存的要素形状的已解析要素。

准备TFRecord文件时,我将这些功能另存为浮动列表并展平。另外,我还保存了每个文件的形状,以便以后在读取这些文件时可以正确地读取它。

当我遍历数据而不批量处理数据时,我得到了正常的特征数组。例如:如果我的X(输入)值的要素形状为(10,342),那么我将获得10组342个值。但是,如果我将数据库批处理设置为2,则突然我的要素形状将变为(1、10、342)。就像TensorFlow批处理不能识别出10组数据,而是将其整体处理一样。

我读取TFRecord文件的代码如下:

x_shape = (10, 342)
y_shape = (10, 311)
p_shape = (10,)

def _parse_function(example_proto): 
    keys_to_features = {'X':tf.FixedLenFeature(x_shape, tf.float32),
            'Y':tf.FixedLenFeature(y_shape, tf.float32),
            'P':tf.FixedLenFeature(p_shape, tf.float32)}

    parsed_features = tf.parse_single_example(example_proto, keys_to_features)

    return parsed_features['X'], parsed_features['Y'], parsed_features['P']

def _load_tfrecord():
    training_filenames = [TF_DATABASE_PATH + TF_DATABASE_NAME]

    dataset = tf.data.TFRecordDataset(training_filenames)
    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(2)

    init = tf.global_variables_initializer()
    iterator = dataset.make_initializable_iterator()
    nextElement = iterator.get_next()

    with tf.Session() as sess:          
        sess.run(init)
        sess.run(iterator.initializer)
        currentBatch = sess.run(nextElement)

所以我的问题是如何使TensorFlow将数据分为正确的批次?在这种情况下,如果我有10组数据(每个X具有342个元素),那么我希望每组有5批2组。

这是我第一次使用TFRecord,非常感谢您的帮助!谢谢。

0 个答案:

没有答案