我正在尝试使用tf.data.datasets建立Tensorflow管道,以便将一些TFRecord加载到Keras模型中。这些数据是多元时间序列。
我当前正在使用Tensorflow 2.0
首先,我从TFRecord获取数据集并进行解析:
dataset = tf.data.TFRecordDataset('...')
context_features = {...}
sequence_features = {...}
def _parse_function(example_proto):
_, sequence = tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
return sequence
dataset = dataset.map(_parse_function)
现在的问题是,它给了我一个MapDataset,其中包含EagerTensor的字典:
for data in dataset.take(3):
print(type(data))
<class 'dict'>
<class 'dict'>
<class 'dict'>
# which look like : {feature1 : EagerTensor, feature2 : EagerTensor ...}
由于这些字典,我似乎无法设法将这些数据进行批处理,改组……以便以后在LSTM层中使用它们。例如:
def make_window_dataset(ds, window_size=5, shift=1, stride=1):
windows = ds.window(window_size, shift=shift, stride=stride)
def sub_to_batch(sub):
return sub.values().batch(window_size, drop_remainder=True)
windows = windows.flat_map(sub_to_batch)
return windows
ds = make_window_dataset(dataset, 10)
gives me :
AttributeError: 'dict_values' object has no attribute 'batch'
感谢您的帮助。我将基于此和其他Tensorflow帮助器进行研究:
https://www.tensorflow.org/guide/data#time_series_windowing
编辑:
我找到了解决问题的方法。我最终使用解析函数中的堆栈将解析给出的字典转换为(None,11)形状的Tensor:
def _parse_function(example_proto):
# Parse the input `tf.Example` proto using the dictionary above.
_, sequence = tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
return tf.stack(list(sequence.values()), axis=1)
答案 0 :(得分:0)
即使在问题部分中也提供了解决方案(答案部分),也可以为社区带来好处。
使用 parse_function 中的(None,11)
将字典转换为形状为tf.stack
的张量已解决了该问题。
更改密码
def _parse_function(example_proto):
_, sequence = tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
return sequence
到
def _parse_function(example_proto):
_, sequence = tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
return tf.stack(list(sequence.values()), axis=1)