Using the Tensorflow Dataset API to process a time series and feed it to an LSTM in batches

Date: 2018-11-21 13:42:24

Tags: python tensorflow lstm tensorflow-datasets

I have a working solution, but I find it relatively slow, and I would be surprised if there were no better way to do what I want. See the problem description below.

Current solution (shown first, since experienced Tensorflow users may find it easier to infer what I am trying to do from this snippet):

def stack_batch_dim(*x):
    return {key: tf.stack([b[key] for b in x], axis=1) for key in x[0]}


dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(map_func=parse_box2d_protobuf, num_parallel_calls=8)
dataset = dataset.repeat()

# place the "cursors":
datasets = [dataset.skip(t * inter_cursor_space) for t in range(batch_size)]
dataset = tf.data.Dataset.zip(tuple(datasets))

# consume time_depth samples at a time
dataset = dataset.batch(time_depth)

# the actual batch dimension is a tuple and not a tensor
# need to "stack" the tuple into a tensor
dataset = dataset.map(stack_batch_dim)
iterator = dataset.make_initializable_iterator()
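To make the `stack_batch_dim` step above concrete, here is a NumPy analogue of what it does (an illustration only, not part of the pipeline): each element of `x` is one cursor's dict of `[time_depth, ...]` arrays, and stacking on `axis=1` turns the tuple of cursors into a real batch dimension.

```python
import numpy as np

# NumPy analogue of stack_batch_dim: stack the per-cursor arrays so that
# axis 0 is time and axis 1 is the batch (cursor) dimension.
def stack_batch_dim_np(*x):
    return {key: np.stack([b[key] for b in x], axis=1) for key in x[0]}

# two cursors, each having consumed time_depth=3 "index" values
cursor_a = {"index": np.array([0, 1, 2])}
cursor_b = {"index": np.array([5, 6, 7])}

stacked = stack_batch_dim_np(cursor_a, cursor_b)
# stacked["index"] has shape [time_depth, batch_size] = [3, 2]:
# [[0, 5],
#  [1, 6],
#  [2, 7]]
```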

So I want to train an LSTM network on a sequence of data points. The data I have is stored in tfrecords (the recommended way to store data with Tensorflow).

dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(map_func=parse_box2d_protobuf, num_parallel_calls=8)

# parse_box2d_protobuf is a function that transforms raw bytes into
# a python structure of the form:
# {"positions": tensor_of_shape_4,
#  "speeds": tensor_of_shape_4,
#  "action": tensor_of_shape_4,
#  "vision": tensor_of_shape_h_w_3,
#  "index": tensor_int64}


iterator = dataset.make_initializable_iterator()
next_sample = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run(next_sample))
    # prints a dictionary as expected.

Now what I want is to read this sequence so that it can be fed to the LSTM in batches. As you might have guessed, the "index" key of the dictionary is an integer giving the index of the data sample within the sequence (starting from 0):

with tf.Session() as sess:
    sess.run(iterator.initializer)
    for _ in range(5):
        print(sess.run(next_sample["index"]))
        # prints successively 0, 1, 2, 3 and 4

Now let's say I want a batch size of 4, and I want to unroll the LSTM network for 3 iterations. That means I need to read the sequence at 4 positions simultaneously (4 cursors), each cursor reading samples 3 at a time. Let's add a parameter to control the spacing between cursors:

batch_size = 4
time_depth = 3
inter_cursor_space = 5
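The cursor behaviour described above can be checked with a small pure-Python simulation of the skip/zip/batch pipeline (an illustration of the intended index pattern, not Tensorflow code):

```python
from itertools import count, islice

batch_size = 4
time_depth = 3
inter_cursor_space = 5

# dataset.skip(start) on an infinite stream of indices 0, 1, 2, ...
cursors = [count(t * inter_cursor_space) for t in range(batch_size)]

def next_batch():
    # each call consumes time_depth samples from every cursor (zip + batch),
    # then transposes so that the result has shape [time_depth, batch_size]
    columns = [list(islice(c, time_depth)) for c in cursors]
    return [list(row) for row in zip(*columns)]

first = next_batch()   # [[0, 5, 10, 15], [1, 6, 11, 16], [2, 7, 12, 17]]
second = next_batch()  # [[3, 8, 13, 18], [4, 9, 14, 19], [5, 10, 15, 20]]
```

This reproduces exactly the two batches expected in the question below.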

How can I define the iterator so that calling iterator.get_next()["index"] in a session returns:

with tf.Session() as sess:
    sess.run(iterator.initializer)

    # first call:
    print(sess.run(next_sample["index"]))

    # should print
    # [[0, 5, 10, 15],
    #  [1, 6, 11, 16],
    #  [2, 7, 12, 17]]

    # second call:
    print(sess.run(next_sample["index"]))

    # should print
    # [[3, 8, 13, 18],
    #  [4, 9, 14, 19],
    #  [5, 10, 15, 20]]

    # shape = [time_depth, batch_size]

An iterator like this would be compatible with the tf.nn.static_rnn function.
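For context on why the `[time_depth, batch_size]` shape matters: `tf.nn.static_rnn` expects a length-`time_depth` list of `[batch_size, ...]` tensors, which `tf.unstack(batch, axis=0)` produces directly from a batch shaped this way (something like `tf.nn.static_rnn(cell, tf.unstack(batch, axis=0), dtype=tf.float32)` in graph code). Here is a NumPy illustration of that unstacking, with assumed feature size:

```python
import numpy as np

# assumed shapes for illustration only
time_depth, batch_size, feature_size = 3, 4, 4
batch = np.arange(time_depth * batch_size * feature_size).reshape(
    time_depth, batch_size, feature_size)

# NumPy analogue of tf.unstack(batch, axis=0): a length-time_depth list
# of [batch_size, feature_size] arrays, one per unrolled RNN step
inputs = [batch[t] for t in range(time_depth)]
```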


0 answers:

No answers