如何在tensorflow中按顺序和随机读取tf.data.Iterator?

时间:2018-04-24 15:38:07

标签: python tensorflow dataset

我在tensorflow中构建了我的tfrecords数据库。现在我想读取记录,使得起点是一些随机值,比如在10到2000之间,然后按顺序读取多个记录,比如在100到200之间。如何使用tf.data.iterator或任何替代方法在张量流中。

非常感谢任何帮助!!

1 个答案:

答案 0 :(得分:2)

您可以使用tf.data.Dataset.taketf.data.Dataset.skip进行此操作。

例如,按如下方式构建您的tf.data.Dataset对象:

starting_point = tf.random_uniform(shape=[], dtype=tf.int64, minval=10, maxval=2000)
num_records = tf.random_uniform(shape=[], dtype=tf.int64, minval=100, maxval=200)

ds = tf.data.TFRecordDataset(...).skip(starting_point).take(num_records)

然后你可以像任何数据集一样构造一个迭代器和“下一个值”张量。例如:

itr = ds.make_one_shot_iterator()
(x, y) = itr.get_next()

希望有所帮助。