我正在尝试使用Tensorflow 2.0在Google Colab上训练pix2pix网络,并且正在使用tf.data.Dataset来导入图像数据(总共约4900个)。但是,由于要训练的图像很多,因此我已使用.batch()函数将数据集分为多个批次。但是,我不知道在训练时一次只使用其中一个批次的方法,而不必将整个数据集加载到RAM中。通常,我在这里https://www.tensorflow.org/beta/tutorials/generative/pix2pix
处遵循代码当我在训练函数中传递整个数据集以进行迭代时,RAM变得疯狂,有时整个笔记本崩溃了。我不确定如何只为每个时期传递一个随机批次来解决此问题。
这就是我加载数据的方式。最后,每个批次的形状为(64,256,256,3)。
train_dataset = tf.data.Dataset.list_files(workpath + '/train/*.jpg')
train_dataset = train_dataset.shuffle(4900, seed=23).take(4900)
train_dataset = train_dataset.map(load_image_train,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(64)
我的火车功能如下所示。
def train(dataset, epochs):
for epoch in range(epochs):
for input_image, target in dataset:
gen, disc = train_step(input_image, target)
print('Gen loss: {} Disc loss: {}'.format(gen, disc))
train(train_dataset, 60)
如果我将整个train_dataset传递给train函数,它将在每个时期循环遍历所有批次。我试图找出一种方法,只为每个时期传递不同的批次。我能想到的唯一方法是,重新组合批次并为每个时期从整个数据集中挑选一个,但这并不能真正解决内存问题,因为我仍然必须将整个数据集传入。这样做的方式?