Keras模型fit_generator奇数错误

时间:2017-03-15 23:15:03

标签: python deep-learning keras conv-neural-network

我有以下代码:

    datagen = ImageDataGenerator(
        rescale=1./255,
        target_size=(128, 128),
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

    test_datagen = ImageDataGenerator(
        rescale=1./255,
        target_size=(128, 128)
    )

    datagen.fit(X_train)

    model.fit_generator(
        datagen.flow(X_train, Y_train),
        samples_per_epoch=len(X_train),
        epochs=30,
        verbose=1,
        validation_data=(X_valid, Y_valid))

这引发了这个异常错误

  Traceback (most recent call last):
      File "cnn.py", line 258, in <module>
          models = run_cross_validation_create_models(num_folds)
      File "cnn.py", line 205, in run_cross_validation_create_models
          validation_data=(X_valid, Y_valid))
      TypeError: fit_generator() takes at least 4 arguments (5 given)

有人可以解释这里出了什么问题,我正在加载一组3700张图片。

1 个答案:

答案 0 :(得分:3)

它可能来自新API(昨天发布的Keras 2.0),fit_generator()现在需要steps_per_epoch参数而不是samples_per_epoch

steps_per_epoch通常是samples_per_epoch / batch_size

您可以找到此信息in the documentation

有帮助吗?