Keras fit_generator一次训练一个样本,而我从生成器中获取多个样本

时间:2019-06-14 12:31:43

标签: python machine-learning keras deep-learning

我正在使用keras训练模型。我尝试了“ fit”和“ fit_generator”功能。而且我不明白为什么性能会有很多差异,可能是我做错了。这是我第一次编写batch_generator代码。

鉴于批次大小为10,我观察到的是使用function-

fit :它训练得更快(每个时期约3分钟),详细数量随着批大小的倍数而增加(此处为10)
样品- 80/7632 [..............................]-ETA:4:31-损失:2.2072-acc:0.4375 < / p>

fit_generator :训练速度慢得多(每个时期10分钟),详细计数一次增加1(不等于批量) >
样品- 37/7632 [....................................]-ETA:42:25-损失:2.1845-acc:0.3676 < / p>

如您所见,对于同一数据集,fit_generator的ETA太高。并且fit_generator每次增加1,而fit以10的倍数增加

生成器:

def batch_generator(X ,y, batch_size=10):
    from sklearn.utils import shuffle

    batch_count = int(len(X) / batch_size)
    extra = len(X) - (batch_count * batch_size)

    while 1:
        #shuffle X and y
        X_train, y_train = shuffle(X,y)

        #Yeild Batches
        for i in range(1, batch_count):
            batch_start = (i-1) * batch_size
            batch_end = i * batch_size
            X_batch = X_train[batch_start: batch_end]
            y_batch = y_train[batch_start: batch_end]
            yield X_batch, y_batch

        #Yeild Remaining Data less than batch size
        if(extra > 0):
            batch_start = batch_count * batch_size
            X_batch = X_train[batch_start: -1]
            y_batch = y_train[batch_start: -1]
            yield X_batch, y_batch

调整功能:

model.fit_generator(batch_generator(X, y, 10),
                    verbose = 1,
                    samples_per_epoch = len(X),
                    epochs = 20,
                    validation_data = (X_test, y_test),
                    callbacks = callbacks_list)

谁能解释为什么会这样?

1 个答案:

答案 0 :(得分:1)

fit_generator不使用示例,它使用步骤,您正在使用带有samples_per_epoch参数的旧Keras API,这是不正确的并且会产生错误的结果。正确的fit_generator调用是:

model.fit_generator(batch_generator(X, y, 10),
                    verbose = 1,
                    steps_per_epoch = int(len(X) / batch_size),
                    epochs = 20,
                    validation_data = (X_test, y_test),
                    callbacks = callbacks_list)

steps_per_epoch控制在宣告时期之前要使用多少步(调用生成器)。应将其设置为总样本数除以批次大小。对于fit_generator,进度栏中的索引将引用步骤(批次),而不是示例,因此您无法直接将它们与fit进度栏中的索引进行比较。