tf.data.Iterator.get_next():如何在tf.while_loop中前进?

时间:2018-05-08 15:40:47

标签: python python-3.x tensorflow

目前我尝试在Tensorflow while循环中实现所有训练,但是我遇到了Tensorflow数据集API迭代器的问题。

通常,在调用sess.run()时,Iterator.get_next()会前进到下一个元素。 但是,我需要在一次运行中前进到下一个元素。我该怎么做?

以下小例子显示了我的问题:

import tensorflow as tf
import numpy as np


def for_loop(condition, modifier, body_op, idx=0):
    idx = tf.convert_to_tensor(idx)

    def body(i):
        with tf.control_dependencies([body_op(i)]):
            return [modifier(i)]

    # do the loop:
    loop = tf.while_loop(condition, body, [idx])
    return loop


x = np.arange(10)

data = tf.data.Dataset.from_tensor_slices(x)
data = data.repeat()

iterator = data.make_initializable_iterator()
smpl = iterator.get_next()

loop = for_loop(
    condition=lambda i: tf.less(i, 5),
    modifier=lambda i: tf.add(i, 1),
    body_op=lambda i: tf.Print(smpl, [smpl], message="This is sample: ")
)

sess = tf.InteractiveSession()
sess.run(iterator.initializer)
sess.run(loop)

输出:

This is sample: [0]
This is sample: [0]
This is sample: [0]
This is sample: [0]
This is sample: [0]

我总是得到完全相同的元素。

1 个答案:

答案 0 :(得分:2)

每次你想“在一次运行中迭代”时,你需要调用iterator.get_next()

例如,在您的玩具示例中,只需将body_op替换为:

 body_op=lambda i: tf.Print(i, [iterator.get_next()], message="This is sample: ")
# This is sample: [0]
# This is sample: [1]
# This is sample: [2]
# This is sample: [3]
# This is sample: [4]