I'm reading TFRecord data like this:

    import tensorflow as tf  # TensorFlow 1.x API (tf.VarLenFeature, tf.parse_single_example)


    def _parse_function(example_proto):
        features = {
            'data': tf.VarLenFeature(tf.float32),
            'label': tf.VarLenFeature(tf.float32),
            'resolution': tf.FixedLenFeature([], tf.int64)
        }
        parsed_features = tf.parse_single_example(example_proto, features)
        resolution = tf.cast(parsed_features['resolution'], tf.int32)
        # VarLenFeature yields SparseTensors; densify them before reshaping.
        tensor_feature1 = tf.sparse_tensor_to_dense(parsed_features['data'])
        tensor_feature2 = tf.sparse_tensor_to_dense(parsed_features['label'])
        ...
        input = tf.reshape(tensor_feature1, [1, 256, 256])
        output = tf.reshape(tensor_feature2, [1, 256, 256])
        return input, output


    # _parse_function must be defined before it is passed to map().
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_function)

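Here I can only parse one example at a time. Is it possible to concatenate my features so that the input samples are stacked, roughly like this: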

    # Pseudocode: what I would like to achieve
    inputs_list = []
    for i in range(0, 20):
        parsed_features = tf.parse_single_example(example_proto, features)
        tensor_feature1 = tf.sparse_tensor_to_dense(parsed_features['data'])
        inputs_list.append(tensor_feature1)
    inputs = tf.stack(inputs_list, axis=0)  # shape = [20, 256, 256]
    return inputs, output

**EDIT**
I have made some progress:

    datasets = []
    for idx in range(20):
        dataset = tf.data.TFRecordDataset(filenames, compression_type='GZIP')
        dataset = dataset.skip(idx)
        dataset = dataset.map(_parse_function, num_parallel_calls=tf.constant(FLAGS.num_parallel_calls, dtype=tf.int32))
        dataset = dataset.batch(20)
        datasets.append(dataset)

Since `batch` only gives me consecutive, non-overlapping batches, I use `.skip(idx)` hoping to offset my starting point, so that:

`[1,2,3,4,5] --> [1,2][2,3][3,4]` rather than: `[1,2][3,4]`
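To make the difference concrete, here is a plain-Python sketch (not tf.data) contrasting non-overlapping batches, which is what `Dataset.batch()` produces, with the overlapping, shift-by-one windows described above. The helper names `consecutive_batches` and `sliding_windows` are hypothetical. In tf.data itself, TF 1.12 and later provide `Dataset.window(size, shift=1)`, whose nested-dataset windows are typically flattened with `flat_map(lambda w: w.batch(size))`; check that your version has it before relying on it.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def consecutive_batches(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Non-overlapping batches, like Dataset.batch(size, drop_remainder=True)."""
    it = iter(items)
    while True:
        batch = list(islice(it, size))
        if len(batch) < size:  # drop the incomplete final batch
            return
        yield batch


def sliding_windows(items: Iterable[T], size: int, shift: int = 1) -> Iterator[List[T]]:
    """Overlapping windows advancing by `shift`; incomplete windows are dropped."""
    buf = list(items)
    for start in range(0, len(buf) - size + 1, shift):
        yield buf[start:start + size]


data = [1, 2, 3, 4, 5]
print(list(consecutive_batches(data, 2)))  # [[1, 2], [3, 4]]
print(list(sliding_windows(data, 2)))      # [[1, 2], [2, 3], [3, 4], [4, 5]]
```

For five elements the shift-by-one windows also include `[4, 5]`; both helpers drop windows shorter than `size`.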
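I'm not sure whether this is correct. The only remaining problem is that I also load 20 outputs instead of one. I was thinking of using `zip`, but I haven't managed to get it working.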
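One way to think about the "20 inputs, one output" problem is to build each sample from a window of `(input, output)` pairs, stacking the inputs and keeping a single output. The plain-Python sketch below assumes the target for a window is the output of its last element (an assumption; the question does not say which output to keep), and `windows_with_last_label` is a hypothetical helper. An untested tf.data analogue of the same idea would be `dataset.window(20, shift=1, drop_remainder=True).flat_map(lambda x, y: tf.data.Dataset.zip((x.batch(20), y.skip(19))))`, where `zip` pairs the stacked inputs with the last output of each window.

```python
from typing import Iterator, List, Sequence, Tuple, TypeVar

X = TypeVar("X")
Y = TypeVar("Y")


def windows_with_last_label(
    pairs: Sequence[Tuple[X, Y]], size: int
) -> Iterator[Tuple[List[X], Y]]:
    """Yield (the `size` inputs of a window, the output of its last element).

    Assumption: the single target for a window is the last element's output.
    Windows advance by one element; incomplete windows are not produced.
    """
    for start in range(len(pairs) - size + 1):
        window = pairs[start:start + size]
        inputs = [x for x, _ in window]  # `size` inputs, to be stacked
        label = window[-1][1]            # keep only one output per window
        yield inputs, label


pairs = [(1, "a"), (2, "b"), (3, "c"), (4, "d")]
for inputs, label in windows_with_last_label(pairs, 3):
    print(inputs, label)  # [1, 2, 3] c, then [2, 3, 4] d
```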
Answer 0 (score: 0)
The easiest way to achieve this is to perform the per-element computation in `Dataset.map()` and then use `Dataset.batch()` to stack consecutive elements together:

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_function)  # Using _parse_function from your question.
    dataset = dataset.batch(20)  # Stack together 20 consecutive elements.
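A quick way to check the shapes this produces: because the question's `_parse_function` reshapes each element to `[1, 256, 256]`, `batch(20)` should yield inputs of shape `[20, 1, 256, 256]` rather than the `[20, 256, 256]` asked for. The sketch below uses NumPy's `np.stack` as a stand-in for the stacking `batch()` performs (an assumption for illustration; no TensorFlow is needed) and zero-filled placeholder arrays instead of real records.

```python
import numpy as np

# Placeholder for one parsed element as _parse_function returns it (zeros, not real data).
per_element_input = np.zeros((1, 256, 256), dtype=np.float32)

# batch(20) stacks 20 consecutive elements along a new leading axis; np.stack mimics that.
batched = np.stack([per_element_input] * 20, axis=0)
print(batched.shape)  # (20, 1, 256, 256)

# Reshaping each element to [256, 256] instead gives the desired [20, 256, 256].
per_element_2d = per_element_input.reshape(256, 256)
batched_2d = np.stack([per_element_2d] * 20, axis=0)
print(batched_2d.shape)  # (20, 256, 256)
```

If `[20, 256, 256]` is wanted, reshaping to `[256, 256]` inside `_parse_function` (or applying `tf.squeeze(..., axis=1)` after batching) should give it. This approach groups records into non-overlapping blocks of 20, and each batch also carries 20 outputs.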