尝试通过概率性地过滤一些数据(每个时期不同)来使数据集中的每个标签具有相同的数量,我遇到了一个意外错误,该错误似乎是来自tensorflow数据集工具的重复函数的实现错误。如果我们在重复之前调用filter,并且该过滤器不幸地滤除了一个纪元(不一定是第一个纪元)的所有数据,它将以“ OutOfRangeError(回溯见上文):序列结束”错误停止,并且应该因为流仍然可以产生数据。但是,使用重复然后过滤,没有问题。这是再现错误的最少代码:
from tensorflow.python.data.ops import dataset_ops
import tensorflow as tf
def filtering(x):
return tf.less(tf.random_uniform((), minval=0, maxval=1, dtype=tf.float32, seed=42), 0.05)
dataset = dataset_ops.Dataset.range(100).filter(filtering).repeat()
get_next = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
for i in range(10000):
print(sess.run(get_next))
我做错了什么还是实现错误?