I have a working solution, but I find it relatively slow, and I would be surprised if there were not a better way to do what I want. See the problem description below.
Current solution (shown first, because an experienced TensorFlow user might infer from this snippet what I am trying to do, which may be simpler than my description):
def stack_batch_dim(*x):
    # turn the tuple of dictionaries into a dictionary of tensors,
    # stacking the batch dimension on axis 1
    return {key: tf.stack([b[key] for b in x], axis=1) for key in x[0]}

dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(map_func=parse_box2d_protobuf, num_parallel_calls=8)
dataset = dataset.repeat()
# place the "cursors":
datasets = [dataset.skip(t * inter_cursor_space) for t in range(batch_size)]
dataset = tf.data.Dataset.zip(tuple(datasets))
# consume time_depth samples at a time
dataset = dataset.batch(time_depth)
# the actual batch dimension is a tuple and not a tensor,
# so we need to "stack" the tuple into a tensor
dataset = dataset.map(stack_batch_dim)
iterator = dataset.make_initializable_iterator()
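To make the cursor semantics concrete, here is a pure-Python sketch (no TensorFlow needed) of what the skip/zip/batch pipeline above produces for the "index" field, assuming the repeated dataset yields the infinite index stream 0, 1, 2, …:

```python
from itertools import count

batch_size = 4          # number of cursors
time_depth = 3          # samples consumed per cursor per batch
inter_cursor_space = 5  # offset between cursor start positions

# each cursor models dataset.skip(start) on the repeated index stream
cursors = [count(t * inter_cursor_space) for t in range(batch_size)]

def next_batch():
    # zip + batch(time_depth) + stack on axis 1:
    # each call pulls time_depth samples from every cursor,
    # giving a list of shape [time_depth, batch_size]
    return [[next(c) for c in cursors] for _ in range(time_depth)]

print(next_batch())  # [[0, 5, 10, 15], [1, 6, 11, 16], [2, 7, 12, 17]]
print(next_batch())  # [[3, 8, 13, 18], [4, 9, 14, 19], [5, 10, 15, 20]]
```

This is only a model of the index bookkeeping; the real pipeline performs the same walk over full dictionaries of tensors, which is presumably where the slowness comes from (each `skip` cursor re-reads the underlying records).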
So, I want to train an LSTM network on a sequence of data points. The data I have is stored in tfrecords (the recommended way of storing data with TensorFlow).
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(map_func=parse_box2d_protobuf, num_parallel_calls=8)
# parse_box2d_protobuf is a function that transforms raw bytes into
# a python structure of the form:
# {"positions": tensor_of_shape_4,
#  "speeds": tensor_of_shape_4,
#  "action": tensor_of_shape_4,
#  "vision": tensor_of_shape_h_w_3,
#  "index": tensor_int64}
iterator = dataset.make_initializable_iterator()
next_sample = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run(next_sample))
    # prints a dictionary as expected
What I want to do now is read that sequence so that it can be fed in batches to an LSTM. As you might expect, the "index" key in the dictionary corresponds to an integer giving the index of the data sample in the sequence (starting from 0):
with tf.Session() as sess:
    sess.run(iterator.initializer)
    for i in range(5):
        print(sess.run(next_sample["index"]))
        # prints successively 0, 1, 2, 3 and 4
Now let's say I want a batch size of 4, and I want to unroll the LSTM network for 3 iterations. This means I need to read the sequence at 4 places simultaneously (4 cursors), each cursor reading samples 3 at a time. Let's add a parameter controlling the spacing between the cursors:
batch_size = 4
time_depth = 3
inter_cursor_space = 5
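With these values, the starting position of each cursor is simply its rank times the spacing:

```python
batch_size = 4
inter_cursor_space = 5

# starting index of each of the 4 cursors in the sequence
starts = [t * inter_cursor_space for t in range(batch_size)]
print(starts)  # [0, 5, 10, 15]
```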
How should I define the iterator so that calling iterator.get_next()["index"] in a session returns the following?
with tf.Session() as sess:
    sess.run(iterator.initializer)
    # first call:
    print(sess.run(next_sample["index"]))
    # should print
    # [[ 0,  5, 10, 15],
    #  [ 1,  6, 11, 16],
    #  [ 2,  7, 12, 17]]
    # second call:
    print(sess.run(next_sample["index"]))
    # should print
    # [[ 3,  8, 13, 18],
    #  [ 4,  9, 14, 19],
    #  [ 5, 10, 15, 20]]
    # shape = [time_depth, batch_size]
Such an iterator would be compatible with the tf.nn.static_rnn function.