如何使用Keras fit_generator批量培训CNN?

时间:2017-08-24 10:32:10

标签: keras generator conv-neural-network training-data

如果这是提出我的问题的错误地方,请道歉(如果情况如此,请帮助我提出最好的提案)。我是Keras和Python的新手,所以希望回答记住这一点。

我正在尝试训练以图像为输入的CNN转向模型。它是一个相当大的数据集,所以我创建了一个数据生成器来使用fit_generator()。我不清楚如何使这种方法在批次上进行训练,所以我假设发生器必须将批次返回到fit_generator()。生成器看起来像这样:

def gen(file_name, batchsz = 64):
    csvfile = open(file_name)
    reader = csv.reader(csvfile)
    batchCount = 0
    while True:
        for line in reader:
            inputs = []
            targets = []
            temp_image = cv2.imread(line[1]) # line[1] is path to image
            measurement = line[3] # steering angle
            inputs.append(temp_image)
            targets.append(measurement)
            batchCount += 1
            if batchCount >= batchsz:
                batchCount = 0
                X = np.array(inputs)
                y = np.array(targets)
                yield X, y
        csvfile.seek(0)

它读取包含遥测数据(转向角等)的csv文件和图像样本的路径,并返回大小的数组:batchsz 对fit_generator()的调用如下所示:

tgen = gen('h:/Datasets/dataset14-no.zero.speed.trn.csv', batchsz = 128) # Train data generator
vgen = gen('h:/Datasets/dataset14-no.zero.speed.val.csv', batchsz = 128) # Validation data generator
try:
    model.fit_generator(
        tgen,
        samples_per_epoch=113526,
        nb_epoch=6,
        validation_data=vgen,
        nb_val_samples=20001
    )

数据集包含113526个采样点,但模型训练更新输出如下所示(例如):

  1020/113526 [..............................] - ETA: 27737s - loss: 0.0080
  1021/113526 [..............................] - ETA: 27723s - loss: 0.0080
  1022/113526 [..............................] - ETA: 27709s - loss: 0.0080
  1023/113526 [..............................] - ETA: 27696s - loss: 0.0080

哪个似乎是按样品训练样品(随机?)。 结果模型没用。我以前使用.fit()训练了一个小得多的数据集,整个数据集被加载到内存中,这产生了一个至少可以工作的模型。显然我的fit_generator()方法有问题。非常感谢对此的一些帮助。

1 个答案:

答案 0 :(得分:2)

此:

for line in reader:
    inputs = []
    targets = []

...正在为csv文件中的每一行重置批处理。您没有使用整个数据进行培训,但只使用128个样本。

建议:

for line in reader:

    if batchCount == 0:
        inputs = []
        targets = []  
    ....
    ....

有人评论说,in fit生成器samples_per_epoch应该等于total_samples / batchsz

尽管如此,我认为你的损失无论如何都会下降。如果不是,代码中可能还有另一个问题,可能是加载数据的方式,或者模型的初始化或结构。

尝试绘制图像并在生成器中打印数据:

for X,y in tgen: #careful, this is an infinite loop, make it stop

    print(X.shape[0]) # is this really the number of batches you expect?

    for image in X:
        ...some method to plot X so you can see it, or just print     

    print(y)

检查所产生的值是否与您期望的一样。