I am trying to train a UNet in Keras for image segmentation using the following custom loss function and metric:
from keras import backend as K

def dice_coef(y_true, y_pred):
    '''
    Params: y_true -- the labeled mask corresponding to an rgb image
            y_pred -- the predicted mask of an rgb image
    Returns: dice_coeff -- A metric that accounts for precision and recall
                           on the scale from 0 - 1. The closer to 1, the
                           better.
    Citation (MIT License): https://github.com/jocicmarko/
                            ultrasound-nerve-segmentation/blob/
                            master/train.py
    '''
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    smooth = 1.0
    return (2.0*intersection+smooth)/(K.sum(y_true_f)+K.sum(y_pred_f)+smooth)

def dice_coef_loss(y_true, y_pred):
    '''
    Params: y_true -- the labeled mask corresponding to an rgb image
            y_pred -- the predicted mask of an rgb image
    Returns: 1 - dice_coeff -- the complement of the dice coefficient on
                               the scale from 0 - 1. The closer to 0, the
                               better.
    Citation (MIT License): https://github.com/jocicmarko/
                            ultrasound-nerve-segmentation/blob/
                            master/train.py
    '''
    return 1-dice_coef(y_true, y_pred)
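It works fine when I train on a subset of the images loaded into RAM. However, when I use flow_from_directory and fit_generator to load batches of images incrementally from the whole dataset, I get a negative loss and a Dice coefficient greater than 1. Here is the generator code: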
import keras
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator

image_datagen = ImageDataGenerator()
mask_datagen = ImageDataGenerator()
seed = 1  # same seed for both generators so images and masks stay paired

image_generator = image_datagen.flow_from_directory(
    "train/flow_from_dir/256/dicom/",
    color_mode="grayscale",
    batch_size=32,
    class_mode=None,
    seed=seed)
# Use mask_datagen here (the original reused image_datagen, leaving mask_datagen unused)
mask_generator = mask_datagen.flow_from_directory(
    "train/flow_from_dir/256/mask/",
    color_mode="grayscale",
    batch_size=32,
    class_mode=None,
    seed=seed)
train_generator = zip(image_generator, mask_generator)

model.compile(Adam(0.001), loss=unet_utils.dice_coef_loss, metrics=[unet_utils.dice_coef])
# steps_per_epoch should be an integer, so use floor division
history = model.fit_generator(train_generator, steps_per_epoch=10712 // 32, epochs=300, verbose=1,
                              validation_data=None,
                              callbacks=[PlotLossesKeras()])
Answer 0 (score: 0)
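Thanks to @today's insight, I realized that both the images and the masks were being loaded as arrays with values from 0 to 255. I added a preprocessing function to normalize them to the 0 to 1 range, which solved my problem: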
image_datagen = ImageDataGenerator(preprocessing_function=lambda x: x/255)
mask_datagen = ImageDataGenerator(preprocessing_function=lambda x: x/255)
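
To see why unscaled masks push the Dice coefficient above 1, here is a minimal NumPy sketch of the same formula. The helper name `dice_coef_np` and the tiny 2x2 arrays are made up for illustration; they are not part of the original code, and the prediction is assumed to lie in [0, 1] as a sigmoid output would.

```python
import numpy as np

def dice_coef_np(y_true, y_pred, smooth=1.0):
    # NumPy re-implementation of the dice_coef formula above (hypothetical helper)
    y_true_f = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred_f = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)

# Toy mask as loaded from disk: foreground pixels are 255, not 1
mask_raw = np.array([[0, 255], [255, 0]], dtype=np.float64)
# Toy prediction that matches the mask exactly, in the [0, 1] range
pred = np.array([[0.0, 1.0], [1.0, 0.0]])

dice_raw = dice_coef_np(mask_raw, pred)            # mask left in 0..255
dice_scaled = dice_coef_np(mask_raw / 255, pred)   # mask normalized to 0..1

print("raw mask:    dice =", dice_raw, " loss =", 1 - dice_raw)
print("scaled mask: dice =", dice_scaled, " loss =", 1 - dice_scaled)
```

With the raw mask the numerator grows by a factor of 255 while the prediction's contribution to the denominator stays small, so the ratio exceeds 1 and the loss `1 - dice` becomes negative. After dividing by 255 a perfect prediction gives a Dice coefficient of exactly 1. As far as I know, passing `rescale=1./255` to `ImageDataGenerator` is an equivalent way to apply the same scaling.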