如何在fit_generator中为3D数据提供class_weight?

时间:2019-07-02 21:50:09

标签: python keras

我需要在路面图像上执行语义分割任务。我的数据集不平衡。我正在使用批处理大小为5(由于内存有限)的生成器将数据传递到fit_generator,以使模型适合我的大型数据集。我正在尝试使用class_weight,但它给了我错误:

  

ValueError:3个以上尺寸的目标不支持class_weight

我读到sample_weight用于3D数据(我不完全知道它是如何工作的)。但是,fit_generator没有此参数(sample_weight)。谁能帮助我解决这个问题?

我的训练数据集包含41000张图像,大小为(512,512),并带有12个标签。我的生成器产生的图像尺寸为(batch_size, 512, 512, 1),标签的尺寸为(batch_size, 512, 512, 12)

autoencoder.fit_generator(train_gen,
                          epochs= 200,
                          steps_per_epoch = NO_OF_TRAINING_IMAGES//BATCH_SIZE,
                          shuffle= True,
                          validation_data= val_gen,
                          validation_steps= NO_OF_VAL_IMAGES//BATCH_SIZE,
                          callbacks= callbacks_list,
                          class_weight= class_weight,
                          use_multiprocessing=True,
                          workers=8)

0 个答案:

没有答案