时间序列中的随机批次数据... Tensorflow

时间:2018-08-22 12:14:14

标签: tensorflow random batch-processing rnn

我正在尝试估计Tensorflow中的RNN,需要创建一批数据来提供估计过程。

我想提供随机批次,但是我需要每个随机批次包含不间断的数据。因此,每个批次都是在时间序列中随机开始的,但是包含(比如说20天)不间断的数据。

下面我有一个tensorflow程序,几乎可以完成所有工作...我得到了随机批次,但是每个批次都包含该批次内随机的数据。只需稍稍更改一下代码,就可以使每批数据构成不间断的数据吗?

import tensorflow as tf

num_epochs = 2

# create 2 simple data input 
inc_dataset = tf.data.Dataset.range(12)
dec_dataset = tf.data.Dataset.range(0, -12, -1)

# merge the two data sets
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))

# the only "shuffler" I know in TF 
dataset = dataset.shuffle(buffer_size=10000)

# batches of size 4
dataset = dataset.batch(4)

# repeat the dataset by number of epochs
dataset = dataset.repeat(num_epochs)

# one-shot iterator
sess = tf.Session()
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()


while True:
    try:
        print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
        break

输出将是:

(array([0, 3, 5, 4], dtype=int64), array([ 0, -3, -5, -4], dtype=int64))
(array([7, 8, 1, 6], dtype=int64), array([-7, -8, -1, -6], dtype=int64))
(array([ 9,  2, 11, 10], dtype=int64), array([ -9,  -2, -11, -10], dtype=int64))
(array([9, 0, 5, 3], dtype=int64), array([-9,  0, -5, -3], dtype=int64))
(array([4, 8, 1, 2], dtype=int64), array([-4, -8, -1, -2], dtype=int64))
(array([10,  6, 11,  7], dtype=int64), array([-10,  -6, -11,  -7], dtype=int64))

非常感谢您。

Br。

1 个答案:

答案 0 :(得分:0)

好-只需更改指令的顺序即可解决问题,事实证明。这样很简单:

# batches of size 4
dataset = dataset.batch(4)

# the only "shuffler" I know in TF 
dataset = dataset.shuffle(buffer_size=10000)

输出可以是:

(array([ 8,  9, 10, 11], dtype=int64), array([ -8,  -9, -10, 
-11], dtype=int64))
(array([0, 1, 2, 3], dtype=int64), array([ 0, -1, -2, -3], 
dtype=int64))
(array([4, 5, 6, 7], dtype=int64), array([-4, -5, -6, -7], 
dtype=int64))
(array([ 8,  9, 10, 11], dtype=int64), array([ -8,  -9, -10, 
-11], dtype=int64))
(array([4, 5, 6, 7], dtype=int64), array([-4, -5, -6, -7], 
dtype=int64))
(array([0, 1, 2, 3], dtype=int64), array([ 0, -1, -2, -3], 
dtype=int64))

因此,与第一个时期相比,它还会更改下一个时期的批次顺序。