如何从SequenceExample TFRecord创建带窗口的多元数据集

时间:2019-10-22 13:37:46

标签: python-3.x tensorflow-datasets tensorflow2.0 tf.keras

我正在尝试使用tf.data.datasets建立Tensorflow管道,以便将一些TFRecord加载到Keras模型中。这些数据是多元时间序列。

我当前正在使用Tensorflow 2.0

首先,我从TFRecord获取数据集并进行解析:

dataset = tf.data.TFRecordDataset('...')

context_features = {...}
sequence_features = {...}

def _parse_function(example_proto):
  _, sequence =  tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
  return sequence


dataset = dataset.map(_parse_function)

现在的问题是,它给了我一个MapDataset,其中包含EagerTensor的字典:

for data in dataset.take(3):
  print(type(data))

<class 'dict'>
<class 'dict'>
<class 'dict'>

# which look like : {feature1 : EagerTensor, feature2 : EagerTensor ...}

由于这些字典,我似乎无法设法将这些数据进行批处理,改组……以便以后在LSTM层中使用它们。例如:

def make_window_dataset(ds, window_size=5, shift=1, stride=1):
  windows = ds.window(window_size, shift=shift, stride=stride)

  def sub_to_batch(sub):
    return sub.values().batch(window_size, drop_remainder=True)

  windows = windows.flat_map(sub_to_batch)
  return windows

ds = make_window_dataset(dataset, 10)

gives me :

AttributeError: 'dict_values' object has no attribute 'batch'

感谢您的帮助。我将基于此和其他Tensorflow帮助器进行研究:

https://www.tensorflow.org/guide/data#time_series_windowing

编辑:

我找到了解决问题的方法。我最终使用解析函数中的堆栈将解析给出的字典转换为(None,11)形状的Tensor:

def _parse_function(example_proto):
  # Parse the input `tf.Example` proto using the dictionary above.
  _, sequence =  tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
  return tf.stack(list(sequence.values()), axis=1)

1 个答案:

答案 0 :(得分:0)

即使在问题部分中也提供了解决方案(答案部分),也可以为社区带来好处。

使用 parse_function 中的(None,11)将字典转换为形状为tf.stack的张量已解决了该问题。

更改密码

def _parse_function(example_proto):
  _, sequence =  tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
  return sequence

def _parse_function(example_proto):
  _, sequence =  tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
  return tf.stack(list(sequence.values()), axis=1)