我正在使用TensorFlow处理图像分类模型,并一直在检查所有代码以确保我理解它;除了输入功能的一部分,我对所有这些都了解。
在输入功能中,csv文件(火车/评估数据文件的名称)被转换为两个张量,每列一个。并将图像本身转换为二进制数据。
在父函数make_input_fn中,csv_row不是参数。嵌套在该父函数中的是_input_fn,而嵌套在其中的是DEcode_csv函数。
所以我不明白的是:csv_row不是make_input_fn中的参数,而是decode_csv函数的参数。代码如何知道-需要一种更好的放置方式-csv_row是什么?
我已经在其他地方看到过类似的代码,所以我知道它是正确的,但我只是想了解它的工作原理。
非常感谢任何帮助。
def make_input_fn(csv_of_filenames, batch_size, mode, augment = False):
def _input_fn():
def decode_csv(csv_row):
filename, label = tf.decode_csv(records = csv_row, record_defaults = [[""],[""]])
image_bytes = tf.read_file(filename = filename)
return image_bytes, label
# Create tf.data.dataset from filename
dataset = tf.data.TextLineDataset(filenames = csv_of_filenames).map(map_func = decode_csv)
if augment:
dataset = dataset.map(map_func = read_and_preprocess_with_augment)
else:
dataset = dataset.map(map_func = read_and_preprocess)
if mode == tf.estimator.ModeKeys.TRAIN:
num_epochs = None # indefinitely
dataset = dataset.shuffle(buffer_size = 10 * batch_size)
else:
num_epochs = 1 # end-of-input after this
dataset = dataset.repeat(count = num_epochs).batch(batch_size = batch_size)
return dataset.make_one_shot_iterator().get_next()
return _input_fn