通过Tensorflow数据集生成器迭代批次

时间:2019-08-07 19:23:56

标签: python tensorflow generator

我说我有

sequence = np.array([[1],[2],[3],[4],[5]])

我将生成器定义为

def generator():
    for el in sequence:
        yield el

现在,我希望使用Tensorflow中定义的from_generator()以便从生成器中检索数据。

dataset = tf.data.Dataset().from_generator(generator,
                                       output_types= tf.int64, 
                                       output_shapes=(tf.TensorShape([1])))
iterator = dataset.make_initializable_iterator()
el = iterator.get_next()

为了进行检索,我用过

with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run(el))
    print(sess.run(el))
    print(sess.run(el))
    print(sess.run(el))
    print(sess.run(el))

有没有一种方法可以使用循环获取“ el”,而不是每次都执行sess.run(el)?

1 个答案:

答案 0 :(得分:1)

这应该可以实现您想要的:

with tf.Session() as sess:
    sess.run(iterator.initializer)
    try:
        while True:
            print(sess.run(el))
    except tf.errors.OutOfRangeError:
        print("Iterating finished")
        pass