方法get_next()
的名称有点误导。文档说
返回
tf.Tensor
的嵌套结构,代表下一个元素。在图形模式下,通常应一次调用此方法,并将其结果用作另一计算的输入。然后,典型的循环将对该计算的结果调用
tf.Session.run
。当Iterator.get_next()
操作引发tf.errors.OutOfRangeError
时,循环将终止。以下框架显示了构建训练循环时如何使用此方法:
dataset = ... # A `tf.data.Dataset` object.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Build a TensorFlow graph that does something with each element.
loss = model_function(next_element)
optimizer = ... # A `tf.compat.v1.train.Optimizer` object.
train_op = optimizer.minimize(loss)
with tf.compat.v1.Session() as sess:
try:
while True:
sess.run(train_op)
except tf.errors.OutOfRangeError:
pass
Python还具有一个名为next
的函数,每次我们需要迭代器的下一个元素时都需要调用该函数。但是,根据上面引用的get_next()
的文档,get_next()
应该仅被调用一次,并且其结果应该通过调用会话的方法run
进行评估,因此这有点不太直观,因为我已经习惯了Python的内置函数next
。在this script中,get_next()
也仅被调用,并且在计算的每个步骤中都会评估调用的结果。
get_next()
背后的直觉是什么,它与next()
有何不同?我认为在每次我通过调用方法get_next()
来评估对run
的第一次调用的结果时,都会检索数据集(或可迭代迭代器)的下一个元素,在我上面链接的第二个示例中,但这有点不直观。我不明白为什么即使在阅读了文档中的注释之后,我们也不需要在计算的每一步都调用get_next
(以获得可迭代迭代器的下一个元素)
注意:多次拨打
Iterator.get_next()
是合法的,例如在单个步骤中将不同元素分配到多个设备时。但是,当用户在其训练循环的每次迭代中调用Iterator.get_next()
时,就会出现一个常见的陷阱。Iterator.get_next()
向图中添加操作,执行每个操作会分配资源(包括线程);结果,在训练循环的每次迭代中调用它都会导致速度减慢并最终耗尽资源。为了防止此结果,当使用次数超过可疑的固定阈值时,我们会记录警告。
通常,不清楚迭代器的工作方式。
答案 0 :(得分:1)
这个想法是get_next
向图中添加了一些操作,这样,每当对它们进行求值时,就可以获取数据集中的下一个元素。在每次迭代中,您只需要运行get_next
进行的操作,就无需一次又一次地创建它们。
也许获得直觉的好方法是尝试自己编写一个迭代器。考虑如下内容:
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()
# Make an iterator, returns next element and initializer
def iterator_next(data):
data = tf.convert_to_tensor(data)
i = tf.Variable(0)
# Check we are not out of bounds
with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
# Get next value
next_val_1 = data[i]
# Update index after the value is read
with tf.control_dependencies([next_val_1]):
i_updated = tf.compat.v1.assign_add(i, 1)
with tf.control_dependencies([i_updated]):
next_val_2 = tf.identity(next_val_1)
return next_val_2, i.initializer
# Test
with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
# Example data
data = tf.constant([1, 2, 3, 4])
# Make operations that give you the next element
next_val, iter_init = iterator_next(data)
# Initialize iterator
sess.run(iter_init)
# Iterate until exception is raised
while True:
try:
print(sess.run(next_val))
# assert throws InvalidArgumentError
except tf.errors.InvalidArgumentError: break
输出:
1
2
3
4
在这里,iterator_next
为您提供的内容与迭代器中get_next
所提供的功能相当,外加初始化程序操作。每次运行next_val
时,您都会从data
中获得一个新元素,不需要每次都调用该函数(这就是next
在Python中的工作方式),只需调用一次然后多次评估结果。
编辑:上面的功能iterator_next
也可以简化为以下内容:
def iterator_next(data):
data = tf.convert_to_tensor(data)
# Start from -1
i = tf.Variable(-1)
# First increment i
i_updated = tf.compat.v1.assign_add(i, 1)
with tf.control_dependencies([i_updated]):
# Check i is not out of bounds
with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
# Get next value
next_val = data[i]
return next_val, i.initializer
或更简单:
def iterator_next(data):
data = tf.convert_to_tensor(data)
i = tf.Variable(-1)
i_updated = tf.compat.v1.assign_add(i, 1)
# Using i_updated directly as a value is equivalent to using i with
# a control dependency to i_updated
with tf.control_dependencies([tf.assert_less(i_updated, tf.shape(data)[0])]):
next_val = data[i_updated]
return next_val, i.initializer