如何应用DataGenerator训练和验证数据?

时间:2020-06-18 12:50:21

标签: python data-generation

我在这里https://keras.io/api/utils/python_utils/#sequence-class使用代码,对自定义DataGenerator进行了编码。

 # Here, `x_set` is list of path to the images
 # and `y_set` are the associated classes.

class DataGenerator(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) *
        self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) *
        self.batch_size]

        return np.array([
            resize(imread(file_name), (224, 224))
               for file_name in batch_x]), np.array(batch_y)

现在,我想知道如何将数据生成器应用于我的训练数据和验证数据? 我有X_trainX_val,它们是包含我的图像文件的图像路径的列表,而y_trainy_val是包含一个热编码标签的列表。

然后我可以使用此代码吗?

training_generator = DataGenerator(X_train, y_train)
validation_generator = DataGenerator(X_val, y_val)

然后适合模型吗?

model.fit_generator(generator=training_generator,
                    validation_data=validation_generator)

1 个答案:

答案 0 :(得分:1)

您写的基本上是正确的。不要忘记将batch_size参数传递给您的DataGenerator

另一方面,应该将epochs参数(如您在注释中提到的)传递给model.fit_generator(最好使用model.fit,因为fit_generator方法是deprecated)。如果不通过,则epochs的默认值为1。

还请查看this tutorial,以了解如何使用Sequence类(您可以跳至使用DataGenerator的底部)。在本教程中,将batch_size以外的其他两个参数传递给DataGenerator,因为它们被定义为__init__方法的输入。只要您不定义它们,就不必通过它们。