TensorFlow数据集随机播放每个时代

时间:2017-05-23 01:26:36

标签: python tensorflow

在Tensorflow中数据集类的manual中,它显示了如何对数据进行混洗以及如何对其进行批处理。然而,人们如何将每个时代的数据洗牌并不明显。我已经尝试了下面的内容,但数据的顺序与第一个时期的顺序完全相同。有人知道如何使用数据集在时期之间进行混洗吗?

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)
for _ in range(4):
    print(sess.run(next_batch))

2 个答案:

答案 0 :(得分:9)

我的环境:Python 3.6,TensorFlow 1.4。

TensorFlow已将Dataset添加到tf.data

你应该对data.shuffle的位置保持谨慎。在您的代码中,数据的时期已经放在dataset的缓冲区之前shuffle。这是两个可用于混洗数据集的示例。

随机播放所有元素

# shuffle all elements
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

输出:

epoch 1
[1 4 5]
[3 0 7]
[6 9 8]
[10  2 11]
epoch 2
[2 0 6]
[1 7 4]
[5 3 8]
[11  9 10]

批次之间的混洗,而不是批量洗牌

# shuffle between batches, not shuffle in a batch
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
dataset = dataset.shuffle(buffer_size=buffer_size)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

输出:

epoch 1
[0 1 2]
[6 7 8]
[3 4 5]
[6 7 8]
epoch 2
[3 4 5]
[0 1 2]
[ 9 10 11]
[ 9 10 11]

答案 1 :(得分:1)

在我看来,您在两种情况下都使用相同的next_batch。因此,根据您真正想要的内容,您可能需要在第二次调用next_batch之前重新创建sess.run,如下所示,否则data = data.shuffle(12)对{next_batch没有任何影响您之前在代码中创建的。

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)

"""See how I recreate next_batch after the data has been shuffled"""
next_batch = data.make_one_shot_iterator().get_next()
for _ in range(4):
    print(sess.run(next_batch))

请告诉我这是否有帮助。谢谢。