我正在尝试将图像数据集分成4组,每组4张图像用于创建新的单个图像。最后,我想将所有生成的合成图像批处理为512个批处理。我可以这样做吗?
def read_dataset(filename, mode, batch_size = 512):
def _input_fn():
...
# Parse text lines as comma-separated values (CSV)
dataset = textlines_dataset.apply(
tf.data.experimental.map_and_batch(decode_csv,num_parallel_calls=os.cpu_count(),batch_size = 4))
dataset = dataset.apply(
tf.data.experimental.map_and_batch(augment4,num_parallel_calls=os.cpu_count(),batch_size = batch_size))
dataset = dataset.prefetch(buffer_size=None)
return dataset
return _input_fn