Tensorflow数据集API-将窗口应用于多个序列

时间:2019-05-10 14:31:41

标签: python tensorflow tensorflow-datasets

我想设置一个处理顺序数据的数据管道。序列中的每个数据点都具有固定的维数,例如64x64。我有多个长度可变的序列。因此我的数据集可以简化为:

seq1 = np.arange(5)[:, None, None]
seq2 = np.arange(8)[:, None, None]
seq3 = np.arange(7)[:, None, None]
sequences = [seq1, seq2, seq3]

现在,我要对序列中的一系列时间帧进行操作,从而生成3维数据立方体[N_frames,data_dim1,data_dim2]。

对于单个序列,我在TF的window API中发现了Dataset,这使我可以使用窗口构建数据立方体:

window = 3
shift = 1
ds = tf.data.Dataset.from_tensor_slices(seq1)
ds = ds.window(size=window , shift=shift, drop_remainder=True).flat_map(lambda x: x.batch(window))
for d in ds:
    print(d)

产生

tf.Tensor(
[[[0]]

 [[1]]

 [[2]]], shape=(3, 1, 1), dtype=int32)
tf.Tensor(
[[[1]]

 [[2]]

 [[3]]], shape=(3, 1, 1), dtype=int32)
tf.Tensor(
[[[2]]

 [[3]]

 [[4]]], shape=(3, 1, 1), dtype=int32)

现在,我很难将这个操作转移到我的全部序列中。 如何从序列集中获取所有数据立方体?

1 个答案:

答案 0 :(得分:0)

我自己找到了答案。我在每个序列上分别使用window函数。我将此过程包装在一个小函数中,然后通过flat_map应用于我的序列集:

sequences = [np.arange(5)[:, None, None], np.arange(20, 24)[:, None, None]]

def get_data_cubes(sequence, size, shift=None, stride=1, drop_remainder=False):
    ds = tf.data.Dataset.from_tensor_slices(sequence)
    ds = ds.window(size=size, shift=shift, stride=stride, drop_remainder=drop_remainder)
    ds = ds.flat_map(lambda x: x.batch(size))
    return ds

window = 3
shift = 1
dataset = tf.data.Dataset.from_generator(lambda: sequences, tf.as_dtype(sequences[0].dtype), tf.TensorShape([None, 1, 1]))
dataset = dataset.flat_map(lambda x: get_data_cubes(x, window, shift=shift, drop_remainder=True))

for d in dataset:
    print(d)

产生

tf.Tensor(
[[[0]]

 [[1]]

 [[2]]], shape=(3, 1, 1), dtype=int32)
tf.Tensor(
[[[1]]

 [[2]]

 [[3]]], shape=(3, 1, 1), dtype=int32)
tf.Tensor(
[[[2]]

 [[3]]

 [[4]]], shape=(3, 1, 1), dtype=int32)
tf.Tensor(
[[[20]]

 [[21]]

 [[22]]], shape=(3, 1, 1), dtype=int32)
tf.Tensor(
[[[21]]

 [[22]]

 [[23]]], shape=(3, 1, 1), dtype=int32)

这正是我搜索的结果。 顺便说一句:该数据集可以像标准TF数据集一样经过改组,批量处理等处理。