如何使用支持生成器的Model.fit(在fit_generator弃用之后)

时间:2019-12-17 18:54:18

标签: python tensorflow keras

我在Tensorflow中使用Model.fit_generator时收到了此弃用警告:

WARNING:tensorflow: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.fit, which supports generators.

如何使用Model.fit代替Model.fit_generator

3 个答案:

答案 0 :(得分:8)

如tensorflow文档中所述:

x:输入数据。

  1. 可能是:Numpy数组(或类似数组的数组)或数组列表(如果模型具有多个输入)。
    1. TensorFlow张量或张量列表(如果模型具有多个输入)。
    2. 将输入名称映射到相应数组/张量的字典,如果 模型已命名输入。
    3. tf.data数据集。应该返回(输入,目标)或(输入,目标,sample_weights)的元组
    4. 生成器或keras.utils.Sequence返回(输入,目标)或(输入,目标,样本权重)。下面给出了针对迭代器类型(数据集,生成器,序列)的拆包行为的详细说明。

您可以简单地将生成器传递给 Model.fit ,类似于 Model.fit_generator

data_gen_train = ImageDataGenerator(rescale=1/255.)

data_gen_valid = ImageDataGenerator(rescale=1/255.)

train_generator = data_gen_train.flow_from_directory(train_dir, target_size=(128,128), batch_size=128, class_mode="binary")

valid_generator = data_gen_valid.flow_from_directory(validation_dir, target_size=(128,128), batch_size=128, class_mode="binary")

model.fit(train_generator, epochs=2, validation_data=valid_generator)

答案 1 :(得分:0)

rc1 中的tensorflow 2.1.0开始不推荐使用

Model.fit_generator。 您可以在以下位置找到tf-2.1.0-rc1的文档:https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/Model#fit

您可以看到Model.fit的第一个参数可以使用生成器,因此只需将其传递给生成器即可。

答案 2 :(得分:0)

文档说,如果使用生成器 x = 作为生成器,则 y = 不应该指定与仅传递KeyHandlingContentView实例的fit_generator签名一致。