我刚接触 Keras,只用了几天,还没什么经验。
我已经成功训练了一个用于单个类别的 U-Net 网络,训练时使用下面的代码输入 RGB 图像和灰度掩码:
def train_generator():
    """Yield augmented ``(x_batch, y_batch)`` training batches forever.

    Walks ``ids_train_split`` (a module-level sequence defined elsewhere,
    accessed via ``[start:end]`` and ``.values``) in chunks of ``batch_size``.
    For every id it:

    * loads the image ``input/train/IMG_<field>.JPG``, where ``<field>`` is
      the third ``_``-separated part of the id;
    * loads the grayscale mask ``input/train_mask/<id>.png``;
    * resizes both to ``input_size x input_size``;
    * applies the project's ``randomHueSaturationValue`` (image only),
      ``randomShiftScaleRotate`` and ``randomHorizontalFlip`` (image and mask)
      augmentations;
    * adds a trailing channel axis to the mask.

    Each batch is converted to ``float32`` and divided by 255.

    Yields:
        tuple: ``(x_batch, y_batch)`` NumPy float32 arrays.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image or mask file
            (``cv2.imread`` returns ``None`` in that case instead of raising).
    """
    while True:
        n_ids = len(ids_train_split)
        for start in range(0, n_ids, batch_size):
            end = min(start + batch_size, n_ids)
            ids_train_batch = ids_train_split[start:end]
            x_batch = []
            y_batch = []
            # `sample_id` rather than `id` to avoid shadowing the builtin.
            for sample_id in ids_train_batch.values:
                img_name = 'IMG_' + str(sample_id).split('_')[2]
                image_path = os.path.join("input", "train", "{}.JPG".format(str(img_name)))
                mca_mask_path = os.path.join("input", "train_mask", "{}.png".format(sample_id))

                img = cv2.imread(image_path)
                if img is None:
                    raise FileNotFoundError("could not read image: {}".format(image_path))
                img = cv2.resize(img, (input_size, input_size))

                # Load the mask under the same name the augmentation calls use;
                # previously it was stored in `mask_mca` while the undefined
                # name `mask` was passed on, raising NameError.
                mask = cv2.imread(mca_mask_path, cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    raise FileNotFoundError("could not read mask: {}".format(mca_mask_path))
                mask = cv2.resize(mask, (input_size, input_size))

                img = randomHueSaturationValue(img,
                                               hue_shift_limit=(-50, 50),
                                               sat_shift_limit=(-5, 5),
                                               val_shift_limit=(-15, 15))
                # rotate_limit=(-0, 0) presumably disables rotation; confirm
                # against randomShiftScaleRotate's implementation.
                img, mask = randomShiftScaleRotate(img, mask,
                                                   shift_limit=(-0.0625, 0.0625),
                                                   scale_limit=(-0.1, 0.1),
                                                   rotate_limit=(-0, 0))
                img, mask = randomHorizontalFlip(img, mask)
                # Add a trailing channel axis so each mask is (H, W, 1).
                mask = np.expand_dims(mask, axis=2)
                x_batch.append(img)
                y_batch.append(mask)
            x_batch = np.array(x_batch, np.float32) / 255
            y_batch = np.array(y_batch, np.float32) / 255
            yield x_batch, y_batch
这是我的U-Net模型:
def get_unet_1(pretrained_weights=None, input_shape=(1024, 1024, 3), num_classes=1, learning_rate=0.0001):
    """Build and compile a seven-level U-Net segmentation model.

    Architecture:

    * Encoder: levels with 8, 16, 32, 64, 128, 256 and 512 filters. Each
      level is two (3x3 'same' Conv2D -> BatchNormalization -> ReLU) units
      followed by a 2x2 / stride-2 MaxPooling2D.
    * Bottleneck: two such units with 1024 filters.
    * Decoder: for each level, from 512 down to 8 filters:
      UpSampling2D(2x2), then concatenate (on axis 3) the matching encoder
      output with the upsampled map, then three conv units.
    * Head: a 1x1 Conv2D with ``num_classes`` channels and a sigmoid, so each
      output channel is an independent value in [0, 1].

    Because there are seven 2x2 poolings, the spatial input dimensions should
    be divisible by 2**7 = 128. Otherwise the upsampled maps will not line up
    with the skip connections at concatenation.

    Args:
        pretrained_weights: optional path passed to ``model.load_weights``;
            skipped when falsy.
        input_shape: shape of the ``Input`` layer (height, width, channels).
        num_classes: number of output channels of the final 1x1 conv.
        learning_rate: learning rate for the RMSprop optimizer.

    Returns:
        The compiled Keras ``Model``. It uses RMSprop, the loss returned by
        ``make_loss('bce_dice')`` (defined elsewhere), and the metrics
        ``dice_coef`` and ``'accuracy'``.
    """
    def conv_bn_relu(tensor, filters, repeats):
        # Stack `repeats` x (3x3 'same' conv -> batch norm -> ReLU).
        for _ in range(repeats):
            tensor = Conv2D(filters, (3, 3), padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('relu')(tensor)
        return tensor

    # Filters per encoder level, shallowest first. Each decoder stage uses
    # the same width as the encoder level it is concatenated with.
    level_filters = (8, 16, 32, 64, 128, 256, 512)

    # Keep the layer creation order (encoder, bottleneck, decoder, head):
    # Keras derives default layer names from creation order, and saved
    # weight files may depend on it.
    inputs = Input(shape=input_shape)

    features = inputs
    skips = []
    for filters in level_filters:
        features = conv_bn_relu(features, filters, 2)
        skips.append(features)  # pre-pooling output feeds the decoder
        features = MaxPooling2D((2, 2), strides=(2, 2))(features)

    features = conv_bn_relu(features, 1024, 2)

    for filters, skip in zip(reversed(level_filters), reversed(skips)):
        features = UpSampling2D((2, 2))(features)
        features = concatenate([skip, features], axis=3)
        features = conv_bn_relu(features, filters, 3)

    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(features)

    model = Model(inputs=inputs, outputs=classify)
    # NOTE(review): newer Keras releases expect `learning_rate=` instead of
    # `lr=` for optimizers; confirm against the installed version.
    model.compile(optimizer=RMSprop(lr=learning_rate), loss=make_loss('bce_dice'), metrics=[dice_coef, 'accuracy'])
    if pretrained_weights:
        model.load_weights(pretrained_weights)
    return model
现在我需要把问题改成多类分类,因此不再只有一个掩码,而是有两个。也就是说,同一张训练图像对应两种灰度掩码(`Mca_mask` 和 `NotMca_mask`)。在这种情况下,标准做法是什么?是把两个掩码合并成一个吗?
答案 0(得分:0)
从下面这一行可以看到,输出层使用的是 sigmoid 激活函数:
classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(up0b)
这意味着每个输出通道都会被独立地映射到 [0, 1] 区间,通道之间没有任何依赖关系。这正是多标签分类(同一个样本/像素可以同时属于多个类别)所需要的。顺便一提,另一种把输出映射到 [0, 1] 范围的常见方法是 softmax。它不适合多标签分类,因为 softmax 的各输出之和恒为 1:一个类别的置信度上升时,其他类别的置信度必然下降。不过对于类别互斥的多类分类,softmax 正是常用的选择。
损失函数在下面这一行中定义;从名称 `'bce_dice'` 看,它应是二元交叉熵(BCE)与 Dice 损失的组合:
model.compile(optimizer=RMSprop(lr=learning_rate), loss=make_loss('bce_dice'), metrics=[dice_coef, 'accuracy'])
二元交叉熵对每个输出通道独立计算,因此既适用于单类分类,也适用于多标签分类。它要求输出位于 [0, 1] 范围内,而 sigmoid 正好满足这一点。
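因此,基本上按您现在的配置就已经可以进行多标签分类,您只需要构造多标签形式的标签。例如,如果类别是狗、猫、鸟、马、山羊,而图像中同时包含狗和猫,则标签为 [1, 1, 0, 0, 0],然后按原样训练网络即可。对于分割任务,对应的做法是把每个类别的掩码作为一个通道堆叠起来(例如 `np.stack([mca_mask, not_mca_mask], axis=-1)`),得到形状为 (H, W, 类别数) 的标签,并把 `num_classes` 设为类别数。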