Iterator.get_next方法背后的直觉是什么?

时间:2019-10-30 15:29:09

标签: tensorflow tensorflow-datasets tensorflow2.0

方法get_next()的名称有点误导。文档说

  

返回tf.Tensor的嵌套结构,代表下一个元素。

     

在图形模式下,通常应一次调用此方法,并将其结果用作另一计算的输入。然后,典型的循环将对该计算的结果调用tf.Session.run。当Iterator.get_next()操作引发tf.errors.OutOfRangeError时,循环将终止。以下框架显示了构建训练循环时如何使用此方法:

dataset = ...  # A `tf.data.Dataset` object.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Build a TensorFlow graph that does something with each element.
loss = model_function(next_element)
optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
train_op = optimizer.minimize(loss)

with tf.compat.v1.Session() as sess:
  try:
    while True:
      sess.run(train_op)
  except tf.errors.OutOfRangeError:
    pass

Python还具有一个名为next的函数,每次我们需要迭代器的下一个元素时都需要调用该函数。但是,根据上面引用的get_next()的文档,get_next()应该仅被调用一次,并且其结果应该通过调用会话的方法run进行评估,因此这有点不太直观,因为我已经习惯了Python的内置函数next。在this script中,get_next()也仅被调用,并且在计算的每个步骤中都会评估调用的结果。

get_next()背后的直觉是什么,它与next()有何不同?我认为在每次我通过调用方法get_next()来评估对run的第一次调用的结果时,都会检索数据集(或可迭代迭代器)的下一个元素,在我上面链接的第二个示例中,但这有点不直观。我不明白为什么即使在阅读了文档中的注释之后,我们也不需要在计算的每一步都调用get_next(以获得可迭代迭代器的下一个元素)

  

注意:多次拨打Iterator.get_next()是合法的,例如在单个步骤中将不同元素分配到多个设备时。但是,当用户在其训练循环的每次迭代中调用Iterator.get_next()时,就会出现一个常见的陷阱。 Iterator.get_next()向图中添加操作,执行每个操作会分配资源(包括线程);结果,在训练循环的每次迭代中调用它都会导致速度减慢并最终耗尽资源。为了防止此结果,当使用次数超过可疑的固定阈值时,我们会记录警告。

通常,不清楚迭代器的工作方式。

1 个答案:

答案 0 :(得分:1)

这个想法是get_next向图中添加了一些操作,这样,每当对它们进行求值时,就可以获取数据集中的下一个元素。在每次迭代中,您只需要运行get_next进行的操作,就无需一次又一次地创建它们。

也许获得直觉的好方法是尝试自己编写一个迭代器。考虑如下内容:

import tensorflow as tf
tf.compat.v1.disable_v2_behavior()

# Make an iterator, returns next element and initializer
def iterator_next(data):
    data = tf.convert_to_tensor(data)
    i = tf.Variable(0)
    # Check we are not out of bounds
    with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
        # Get next value
        next_val_1 = data[i]
    # Update index after the value is read
    with tf.control_dependencies([next_val_1]):
        i_updated = tf.compat.v1.assign_add(i, 1)
        with tf.control_dependencies([i_updated]):
            next_val_2 = tf.identity(next_val_1)
    return next_val_2, i.initializer

# Test
with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
    # Example data
    data = tf.constant([1, 2, 3, 4])
    # Make operations that give you the next element
    next_val, iter_init = iterator_next(data)
    # Initialize iterator
    sess.run(iter_init)
    # Iterate until exception is raised
    while True:
        try:
            print(sess.run(next_val))
        # assert throws InvalidArgumentError
        except tf.errors.InvalidArgumentError: break

输出:

1
2
3
4

在这里,iterator_next为您提供的内容与迭代器中get_next所提供的功能相当,外加初始化程序操作。每次运行next_val时,您都会从data中获得一个新元素,不需要每次都调用该函数(这就是next在Python中的工作方式),只需调用一次然后多次评估结果。

编辑:上面的功能iterator_next也可以简化为以下内容:

def iterator_next(data):
    data = tf.convert_to_tensor(data)
    # Start from -1
    i = tf.Variable(-1)
    # First increment i
    i_updated = tf.compat.v1.assign_add(i, 1)
    with tf.control_dependencies([i_updated]):
        # Check i is not out of bounds
        with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
            # Get next value
            next_val = data[i]
    return next_val, i.initializer

或更简单:

def iterator_next(data):
    data = tf.convert_to_tensor(data)
    i = tf.Variable(-1)
    i_updated = tf.compat.v1.assign_add(i, 1)
    # Using i_updated directly as a value is equivalent to using i with
    # a control dependency to i_updated
    with tf.control_dependencies([tf.assert_less(i_updated, tf.shape(data)[0])]):
        next_val = data[i_updated]
    return next_val, i.initializer