如何仅通过tf.data.Dataset使用特定批次的数据

时间:2019-06-19 01:25:40

标签: tensorflow machine-learning keras google-colaboratory tensorflow-datasets

我正在尝试使用Tensorflow 2.0在Google Colab上训练pix2pix网络,并且正在使用tf.data.Dataset来导入图像数据(总共约4900个)。但是,由于要训练的图像很多,因此我已使用.batch()函数将数据集分为多个批次。但是,我不知道在训练时一次只使用其中一个批次的方法,而不必将整个数据集加载到RAM中。通常,我在这里https://www.tensorflow.org/beta/tutorials/generative/pix2pix

处遵循代码

当我在训练函数中传递整个数据集以进行迭代时,RAM变得疯狂,有时整个笔记本崩溃了。我不确定如何只为每个时期传递一个随机批次来解决此问题。

这就是我加载数据的方式。最后,每个批次的形状为(64,256,256,3)。

train_dataset = tf.data.Dataset.list_files(workpath + '/train/*.jpg')


train_dataset = train_dataset.shuffle(4900, seed=23).take(4900)

train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)

train_dataset = train_dataset.batch(64)

我的火车功能如下所示。


def train(dataset, epochs):

  for epoch in range(epochs):    
    for input_image, target in dataset:
      gen, disc = train_step(input_image, target)
          print('Gen loss: {} Disc loss: {}'.format(gen, disc))

train(train_dataset, 60)

如果我将整个train_dataset传递给train函数,它将在每个时期循环遍历所有批次。我试图找出一种方法,只为每个时期传递不同的批次。我能想到的唯一方法是,重新组合批次并为每个时期从整个数据集中挑选一个,但这并不能真正解决内存问题,因为我仍然必须将整个数据集传入。这样做的方式?

0 个答案:

没有答案