我正在尝试训练keras_retinanet模型,如下面的代码所示,并且训练工作正常。我为fit_generator函数创建了一个CSVGenerator数据生成器,该函数继承了“ Generator”超类,该类中有一个名为“ batch_size”的参数,默认为“ 1”。
我想更改此“ batch_size”变量的值,但是我无法弄清楚该怎么做。非常感谢您的帮助。
model = load_model('./snapshots/resnet50_csv_01.h5',
backbone_name='resnet50')
generator = CSVGenerator(
csv_data_file='./data_set_retina/train.csv',
csv_class_file='./data_set_retina/class_id_mapping'
)
generator_val = CSVGenerator(
csv_data_file='./data_set_retina/val.csv',
csv_class_file='./data_set_retina/class_id_mapping'
)
model.compile(
loss={
'regression' : keras_retinanet.losses.smooth_l1(),
'classification': keras_retinanet.losses.focal()
},
optimizer=keras.optimizers.adam(lr=1e-5, clipnorm=0.001)
)
csv_logger = keras.callbacks.CSVLogger('./logs/training_log.csv',
separator=',', append=False)
model.fit_generator(generator, steps_per_epoch=80000, epochs=50,
verbose=1, callbacks=[csv_logger],
validation_data=generator_val,validation_steps=20000,class_weight=None,
max_queue_size=10, workers=1, use_multiprocessing=False,
shuffle=True, initial_epoch=0)
答案 0 :(得分:1)
我想您是在谈论keras-retinanet存储库。
您可以在此处找到batch size
:
然后将此变量传递到common_args
词典中的生成器。
实际上,也可以实例化传递CSVGenerator
的{{1}}参数。遵循您的代码段:
batch_size