我正在学习TensorFlow(TF),它只有一天,所以如果我的疑问太基本无法问我,我会提前道歉。 我正在TF官方网站上研究linear classification example。
作者定义了一个名为input_fun
的函数来读取数据。功能如下:
def input_fn(data_file, num_epochs, shuffle, batch_size):
"""Generate an input function for the Estimator."""
assert tf.gfile.Exists(data_file), (
'%s not found. Please make sure you have either run data_download.py or '
'set both arguments --train_data and --test_data.' % data_file)
def parse_csv(value):
print('Parsing', data_file)
columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
features = dict(zip(_CSV_COLUMNS, columns))
labels = features.pop('income_bracket')
return features, tf.equal(labels, '>50K')
# Extract lines from input files using the Dataset API.
dataset = tf.data.TextLineDataset(data_file)
if shuffle:
dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])
dataset = dataset.map(parse_csv, num_parallel_calls=5)
# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
我无法理解倒数第二行。一次性迭代器只调用get_next()
一次但是不应该多次迭代数据(即行数次)以提取行,如this example here?
答案 0 :(得分:2)
所以在这里,get_next()基本上是一个dequeue op。当您使用(使用/调用)get_next()调用的元素时,数据处于队列中,它将从队列中删除,下一个图像/标签将移动到其位置,下次调用它时会出列
目前,此函数仅返回用于出列元素的tensorflow操作,您可以在训练循环中使用它。