全部欢迎。试图了解fit_generator在keras中的工作原理。
每个文件中都有一个数据集-100张图像和100个标签。
我写了这个生成器:
def GenerateData(self):
while True:
complete_x1 = np.zeros((500, 50, 50, 3))
complete_x2 = np.zeros((500, 50, 50, 3))
complete_y1 = np.zeros((500, 3))
complete_y2 = np.zeros((500, 2))
done = 0
while done < 500:
data = np.load("{}/data_resized_{}.npy".format(self._patch, self._LastID))
self.Log('\nLoad ALL data. ID: {} - Done: {}'.format(self._LastID, done))
for data_x1, data_x2, data_y1, data_y2 in data:
data_x1 = self.random_transform(data_x1)
data_x2 = self.random_transform(data_x2)
data_x1 = self.ImageProcessing(data_x1, 0)
data_x2 = self.ImageProcessing(data_x2, 1)
data_x1 = np.array(data_x1).astype('float32')
data_x1 /= 255
data_x2 = np.array(data_x2).astype('float32')
data_x2 /= 255
complete_x1[done] = data_x1
complete_x2[done] = data_x2
complete_y1[done] = data_y1
complete_y2[done] = data_y2
done += 1
self._LastID += 1
if self._LastID >= 1058:
self._LastID = 0
yield [np.array(complete_x1), np.array(complete_x2)], [np.array(complete_y1), np.array(complete_y2)]
我总共有1058个文件。可以显示105800张带有标签的图像。
模型训练:
model.fit_generator(data.GenerateData(), samples_per_epoch=1058/500, nb_epoch=15, verbose=1, workers=1)
一切似乎都不错,但是!
培训开始时,GenerateData打印以下内容:
加载所有数据。 ID:0-完成:0
加载所有数据。 ID:1-完成:100
加载所有数据。 ID:2-完成:200
加载所有数据。 ID:3-完成:300
加载所有数据。 ID:4-完成:400
加载所有数据。 ID:5-完成:0
加载所有数据。 ID:6-完成:100
加载所有数据。 ID:7-完成:200
加载所有数据。 ID:8-完成:300
加载所有数据。 ID:9-完成:400
加载所有数据。 ID:10-完成:0
这发生在ID为59的文件之前。事实证明...它会跳过上至59文件的所有内容吗? 5900张图片?
它只加载500张图像,然后通过 产生并重新开始,并以他完成的文件的ID为准,但是火车不起作用。
以下是第59个文件之后的内容:
加载所有数据。 ID:59-完成:400 1/2 [=============> .......] -ETA:4秒-损失:2.8177-density_18_loss:2.0145-density_21_loss:0.8032-density_18_acc:0.2140-density_21_acc:0.5780加载所有数据。 ID:60-完成:0
加载所有数据。 ID:61-完成:100
加载所有数据。 ID:62-完成:200
加载所有数据。 ID:63-完成:300
加载所有数据。 ID:64-完成:400 2/2 [。========================== ..] -ETA:0秒-损失:2.7260-density_18_loss:1.7077-density_21_loss:1.0183-density_18_acc:0.2720-density_21_acc:0.5890加载所有数据。 ID:65-完成:0
加载所有数据。 ID:66-完成:100
为什么会这样?
答案 0 :(得分:1)
之所以会出现这种现象,是因为将workers
设置为1,并且数据生成任务和训练任务在单独的线程上运行。训练任务在主线程上运行,而数据生成任务在单独的线程上运行,其中线程数取决于workers
参数。
如果workers
参数为0,则数据生成器将在主线程上运行,并且结果将符合您的期望。