Understanding the TFRecordDataset map function in more detail

Time: 2018-03-02 11:00:38

Tags: python tensorflow deep-learning tensorflow-datasets

I am reading a TFRecordDataset like this:

dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

def _parse_function(example_proto):

    features = {
        'data':         tf.VarLenFeature(tf.float32),
        'label':        tf.VarLenFeature(tf.float32),
        'resolution':   tf.FixedLenFeature([], tf.int64)
    }

    parsed_features = tf.parse_single_example(example_proto, features)

    resolution = tf.cast(parsed_features['resolution'], tf.int32)
    tensor_feature1 = tf.sparse_tensor_to_dense(parsed_features['data'])
    tensor_feature2 = tf.sparse_tensor_to_dense(parsed_features['label'])
    ...

    input = tf.reshape(tensor_feature1, [1, 256, 256])
    output = tf.reshape(tensor_feature2, [1, 256, 256])

    return input, output
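
For reference, a record that matches this parser could have been serialized with a schema like the one below. This is only a sketch of an assumed writer; the file name and the writer code itself are not part of the question:

import numpy as np
import tensorflow as tf

# Sketch of an assumed writer matching the features dict in _parse_function.
with tf.python_io.TFRecordWriter('train.tfrecords') as writer:
    data = np.random.rand(256, 256).astype(np.float32)
    label = np.random.rand(256, 256).astype(np.float32)
    example = tf.train.Example(features=tf.train.Features(feature={
        'data':       tf.train.Feature(float_list=tf.train.FloatList(value=data.flatten())),
        'label':      tf.train.Feature(float_list=tf.train.FloatList(value=label.flatten())),
        'resolution': tf.train.Feature(int64_list=tf.train.Int64List(value=[256])),
    }))
    writer.write(example.SerializeToString())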

Here I can only parse one example at a time. Is it possible to combine the parsed examples so that the input samples get stacked, something like this:

frames = []
for i in range(0, 20):
    parsed_features = tf.parse_single_example(example_proto, features)
    frame = tf.sparse_tensor_to_dense(parsed_features['data'])
    frames.append(tf.reshape(frame, [256, 256]))

inputs = tf.stack(frames, axis=0)  # shape = [20, 256, 256]
return inputs, output

**EDIT**

I have made some progress:

datasets = []

for idx in range(20):
    dataset = tf.data.TFRecordDataset(filenames, 'GZIP')
    dataset = dataset.skip(idx)
    dataset = dataset.map(_parse_function, num_parallel_calls=tf.constant(FLAGS.num_parallel_calls, dtype=tf.int32))
    dataset = dataset.batch(20)
    datasets.append(dataset)

Since batching only gives me consecutive, non-overlapping batches, I use .skip(idx), hoping to offset my starting point so that:

[1,2,3,4,5] --> [1,2][2,3][3,4] rather than: [1,2][3,4]

I am not sure whether this is correct. The only remaining problem is that I am now also loading 20 outputs instead of a single one. I am thinking of using Dataset.zip(), but so far I have not been able to get it working.
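
For what it is worth, the zip idea might be completed roughly as follows. This is only a sketch: the _parse_input_only helper, the window alignment, and the choice to keep the output of the last frame in each window are all assumptions, not something stated in the question:

window = 20

def _parse_input_only(example_proto):
    # Hypothetical helper: re-parse only the 'data' feature for each shifted copy.
    features = {'data': tf.VarLenFeature(tf.float32)}
    parsed = tf.parse_single_example(example_proto, features)
    return tf.reshape(tf.sparse_tensor_to_dense(parsed['data']), [1, 256, 256])

# 20 copies of the dataset, each shifted by one record, yielding only the input frame.
shifted_inputs = tuple(
    tf.data.TFRecordDataset(filenames, 'GZIP').skip(i).map(_parse_input_only)
    for i in range(window))

# A single output per window (here: the one aligned with the last frame -- an assumption).
outputs = (tf.data.TFRecordDataset(filenames, 'GZIP')
           .skip(window - 1)
           .map(_parse_function)
           .map(lambda inp, out: out))

# Zip everything and stack the 20 inputs into one tensor of shape [20, 1, 256, 256].
dataset = tf.data.Dataset.zip((shifted_inputs, outputs))
dataset = dataset.map(lambda inputs, output: (tf.stack(inputs, axis=0), output))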

1 Answer:

Answer 0 (score: 0)

The simplest way to achieve this is to perform the per-element computation in Dataset.map(), and then use Dataset.batch() to stack consecutive elements together:

dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)  # Using _parse_function from your question.
dataset = dataset.batch(20)  # Stack together 20 consecutive elements.
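
To then consume the batched elements you could, for example, create a one-shot iterator (a usage sketch, assuming TF 1.x graph mode):

iterator = dataset.make_one_shot_iterator()
inputs, outputs = iterator.get_next()  # each has shape [20, 1, 256, 256]

with tf.Session() as sess:
    batch_inputs, batch_outputs = sess.run([inputs, outputs])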