具有扩展Sequence的生成器的Keras fit_generator()返回的样本数大于总数

时间:2019-07-02 16:25:43

标签: python tensorflow machine-learning keras conv-neural-network

我正在用Keras训练神经网络。由于数据集的大小,我需要使用一个生成器和fit_generator()方法。我正在关注本教程:

https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

但是,我准备了一个小例子来检查在每个时期馈入网络的样本,看来数量大于样本数量。

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, files, batch_size=2, dim=(160, 160), n_channels=3,
                 n_classes=2, shuffle=False):
        'Initialization'
        self.dim = dim
        self.files = files
        self.batch_size = batch_size
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        print ("Number of batches per epoch")
        print(int(np.floor(len(self.files) / self.batch_size)))
        return int(np.floor(len(self.files) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        files_temp = [self.files[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(files_temp)

        return X, y


    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.files))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)


    def __data_generation(self, files_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, ID in enumerate(files_temp):
            # Store sample
            X[i,] = read_image(ID)

            # Store class
            y[i] = get_label(ID)

        return X, keras.utils.to_categorical(y, num_classes=self.n_classes)


...

params = {'dim': (160, 160),
              'batch_size': 2,
              'n_classes': 2,
              'n_channels': 3,
              'shuffle': True}


gen_train = DataGenerator(files, **params)
 model.fit_generator(gen_train, steps_per_epoch=ceil(num_samples_train)/batch_size, validation_data=None,
        epochs = 1,  verbose=1,
    callbacks = [tensorboard])

read_imageget_label是我获取数据的方法。这些方法包括要加载的图像的print(),我得到的超出了我的期望。例如:

num_samples = 10 batch_size = 2

每个时期的步数等于5,这就是keras进度条显示的内容,但是我得到了更多的图像(由于方法内部的打印,我知道这一点)。

我尝试调试,发现__getitem__函数被调用了5次以上!前五次索引的索引介于0到4之间(如预期),但是随后我将得到重复的索引并加载更多数据。

知道为什么会这样吗?我已调试到keras中的data_utils.py,但找不到将索引传递到__getitem__的确切位置。 getitem中的所有内容似乎都正常运行。

1 个答案:

答案 0 :(得分:1)

这很正常,对于steps_per_epoch = 5,您的__getitem__将在每个时期被调用5次 。因此,当然,一个以上的纪元意味着它将被调用多次,而不仅仅是5个。

还请注意,其中涉及并行性,Keras会在另一个线程/进程中自动运行Sequence(取决于配置),因此可能会按预期顺序调用它们。这也是正常的。