我的代码具有与tensorflow 2.0 tutorial类似的模式。 我希望我的数据集对象在每个时代都重新洗牌。
dataset = tf.data.Dataset.from_tensor_slices(['a','b','c','d'])
dataset = dataset.shuffle(100)
for epoch in range(10):
for d in dataset:
print(d)
结果:
tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
...
似乎数据集并没有在每个时期都随机播放。 我应该为每个纪元调用.shuffle()吗?
答案 0 :(得分:1)
是的,您应该在内循环中调用.shuffle
。此外,最好在可以使用等同于Python语句的纯tf。*方法时不要混合使用python代码和TensorFlow代码。
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c", "d"])
# dataset = dataset.shuffle(2)
@tf.function
def loop():
for epoch in tf.range(10):
for d in dataset.shuffle(2):
tf.print(d)
loop()
每次调用循环都会产生不同的值(tf.print
打印tf.Tensor
的内容,与打印对象的print
不同)。