在Keras中正确使用fit_generator

时间:2019-05-20 09:57:22

标签: python keras neural-network

全部欢迎。试图了解fit_generator在keras中的工作原理。

每个文件中都有一个数据集-100张图像和100个标签。

我写了这个生成器:

def GenerateData(self):

    while True:

        complete_x1 = np.zeros((500, 50, 50, 3))
        complete_x2 = np.zeros((500, 50, 50, 3))
        complete_y1 = np.zeros((500, 3))
        complete_y2 = np.zeros((500, 2))

        done = 0

        while done < 500:

            data = np.load("{}/data_resized_{}.npy".format(self._patch, self._LastID))

            self.Log('\nLoad ALL data. ID: {} - Done: {}'.format(self._LastID, done))

            for data_x1, data_x2, data_y1, data_y2 in data:

                data_x1 = self.random_transform(data_x1)

                data_x2 = self.random_transform(data_x2)

                data_x1 = self.ImageProcessing(data_x1, 0)

                data_x2 = self.ImageProcessing(data_x2, 1)

                data_x1 = np.array(data_x1).astype('float32')
                data_x1 /= 255

                data_x2 = np.array(data_x2).astype('float32')
                data_x2 /= 255

                complete_x1[done] = data_x1
                complete_x2[done] = data_x2

                complete_y1[done] = data_y1
                complete_y2[done] = data_y2

                done += 1

            self._LastID += 1

            if self._LastID >= 1058:
                self._LastID = 0

        yield [np.array(complete_x1), np.array(complete_x2)], [np.array(complete_y1), np.array(complete_y2)]

我总共有1058个文件。可以显示105800张带有标签的图像。

模型训练:

model.fit_generator(data.GenerateData(), samples_per_epoch=1058/500, nb_epoch=15, verbose=1, workers=1)

一切似乎都不错,但是!

培训开始时,GenerateData打印以下内容:

  

加载所有数据。 ID:0-完成:0

     

加载所有数据。 ID:1-完成:100

     

加载所有数据。 ID:2-完成:200

     

加载所有数据。 ID:3-完成:300

     

加载所有数据。 ID:4-完成:400

     

加载所有数据。 ID:5-完成:0

     

加载所有数据。 ID:6-完成:100

     

加载所有数据。 ID:7-完成:200

     

加载所有数据。 ID:8-完成:300

     

加载所有数据。 ID:9-完成:400

     

加载所有数据。 ID:10-完成:0

这发生在ID为59的文件之前。事实证明...它会跳过上至59文件的所有内容吗? 5900张图片?

它只加载500张图像,然后通过 产生并重新开始,并以他完成的文件的ID为准,但是火车不起作用。

以下是第59个文件之后的内容:

  

加载所有数据。 ID:59-完成:400 1/2 [=============> .......]   -ETA:4秒-损失:2.8177-density_18_loss:2.0145-density_21_loss:0.8032-density_18_acc:0.2140-density_21_acc:0.5780加载所有数据。 ID:60-完成:0

     

加载所有数据。 ID:61-完成:100

     

加载所有数据。 ID:62-完成:200

     

加载所有数据。 ID:63-完成:300

     

加载所有数据。 ID:64-完成:400 2/2 [。========================== ..]   -ETA:0秒-损失:2.7260-density_18_loss:1.7077-density_21_loss:1.0183-density_18_acc:0.2720-density_21_acc:0.5890加载所有数据。 ID:65-完成:0

     

加载所有数据。 ID:66-完成:100

为什么会这样?

1 个答案:

答案 0 :(得分:1)

之所以会出现这种现象,是因为将workers设置为1,并且数据生成任务和训练任务在单独的线程上运行。训练任务在主线程上运行,而数据生成任务在单独的线程上运行,其中线程数取决于workers参数。

如果workers参数为0,则数据生成器将在主线程上运行,并且结果将符合您的期望。