如何在Keras Retinanet培训中更改批处理大小

时间:2018-12-12 10:28:19

标签: python-3.x oop keras

我正在尝试训练keras_retinanet模型,如下面的代码所示,并且训练工作正常。我为fit_generator函数创建了一个CSVGenerator数据生成器,该函数继承了“ Generator”超类,该类中有一个名为“ batch_size”的参数,默认为“ 1”。

我想更改此“ batch_size”变量的值,但是我无法弄清楚该怎么做。非常感谢您的帮助。

model = load_model('./snapshots/resnet50_csv_01.h5', 
backbone_name='resnet50')

generator = CSVGenerator(
    csv_data_file='./data_set_retina/train.csv',
    csv_class_file='./data_set_retina/class_id_mapping'
)


generator_val = CSVGenerator(
    csv_data_file='./data_set_retina/val.csv',
    csv_class_file='./data_set_retina/class_id_mapping'
)
model.compile(
    loss={
        'regression'    : keras_retinanet.losses.smooth_l1(),
        'classification': keras_retinanet.losses.focal()
    },
    optimizer=keras.optimizers.adam(lr=1e-5, clipnorm=0.001)
)
csv_logger = keras.callbacks.CSVLogger('./logs/training_log.csv', 
separator=',', append=False)

model.fit_generator(generator, steps_per_epoch=80000, epochs=50, 
verbose=1, callbacks=[csv_logger], 

validation_data=generator_val,validation_steps=20000,class_weight=None, 
max_queue_size=10, workers=1, use_multiprocessing=False,
                shuffle=True, initial_epoch=0)

1 个答案:

答案 0 :(得分:1)

我想您是在谈论keras-retinanet存储库。

您可以在此处找到batch size

https://github.com/fizyr/keras-retinanet/blob/b28c358c71026d7a5bcb1f4d928241a693d95230/keras_retinanet/bin/train.py#L395

然后将此变量传递到common_args词典中的生成器。

实际上,也可以实例化传递CSVGenerator的{​​{1}}参数。遵循您的代码段:

batch_size