我创建了一个 U-Net 类,可以用它定义深度可变的 U-Net。我正在训练它从图像中分割单个类别,但问题是:当我用训练好的模型对图像进行推断时,返回的数组只包含 1 和 0,而我想看到每个像素属于该类别的实际概率。我该如何实现?
我用来创建模型的类:
class UNet():
    """Keras U-Net with a configurable number of encoder/decoder levels.

    The encoder has ``depth`` blocks. Block ``i`` uses ``base_filters * 2**i``
    filters and halves the spatial resolution with 2x2 max pooling. A
    bottleneck block with ``base_filters * 2**depth`` filters follows. Then
    come ``depth`` decoder blocks, each of which upsamples with a stride-2
    transposed convolution and concatenates the matching encoder skip tensor.
    Decoder level ``i`` is paired with encoder level ``depth - 1 - i`` and uses
    the same number of filters as that encoder level.

    The output layer is a single-channel 1x1 convolution with a sigmoid
    activation. ``self.model.predict`` therefore returns continuous per-pixel
    values in [0, 1], not hard 0/1 labels.

    Args:
        depth: Number of encoder (and decoder) levels; must be >= 1.
        shape: Input shape without the batch axis, e.g. ``(512, 512, 1)``.
            Height and width should be divisible by ``2**depth``. Otherwise
            the upsampled decoder tensors will not match the skip tensors at
            the concatenation.
        base_filters: Number of filters in the first encoder block. The
            default of 32 reproduces the original architecture.

    Attributes:
        model: The assembled ``keras`` ``Model``.

    Raises:
        ValueError: If ``depth`` is smaller than 1.
    """

    def __init__(self, depth, shape, base_filters=32):
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        inputs = layers.Input(shape=shape, name='Input')  # Input layer

        # Encoder: each entry is (pooled_output, pre_pool_skip_tensor).
        encoder_layers = []
        x = inputs
        for i in range(depth):
            pooled, skip = self.encoder_block(x, base_filters * 2 ** i)
            encoder_layers.append((pooled, skip))
            x = pooled

        # Bottleneck between the encoder and the decoder.
        x = self.conv_block(x, base_filters * 2 ** depth)

        # Decoder: walk the encoder levels in reverse. Each decoder block must
        # use the same filter count as the encoder level whose skip tensor it
        # concatenates (the original used 32*2**(i-1), which is only correct
        # for depth == 2).
        for i in range(depth):
            level = depth - 1 - i
            x = self.decoder_block(x, encoder_layers[level][1], base_filters * 2 ** level)

        # Sigmoid output: values in [0, 1] for the single segmented class.
        outputs = layers.Conv2D(1, (1, 1), activation='sigmoid', name='Output')(x)
        self.model = models.Model(inputs=[inputs], outputs=[outputs])

    def encoder_block(self, input_tensor, num_filters):
        """Apply ``conv_block`` followed by 2x2 max pooling.

        Returns:
            A tuple ``(pooled, features)``. ``pooled`` feeds the next encoder
            level. ``features`` is the pre-pooling tensor, which is kept as
            the skip connection for the decoder.
        """
        encoder = self.conv_block(input_tensor, num_filters)
        encoder_pool = layers.MaxPooling2D((2, 2), strides=(2, 2))(encoder)
        return encoder_pool, encoder

    @staticmethod
    def conv_block(input_tensor, num_filters):
        """Apply two rounds of 3x3 'same' Conv2D -> BatchNorm -> ReLU."""
        encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(input_tensor)
        encoder = layers.BatchNormalization()(encoder)
        encoder = layers.Activation('relu')(encoder)
        encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(encoder)
        encoder = layers.BatchNormalization()(encoder)
        encoder = layers.Activation('relu')(encoder)
        return encoder

    @staticmethod
    def decoder_block(input_tensor, concat_tensor, num_filters):
        """Upsample ``input_tensor`` 2x and merge it with a skip tensor.

        The block applies a stride-2 transposed convolution, concatenates
        ``concat_tensor`` on the channel axis, and then runs
        BatchNorm -> ReLU followed by two rounds of
        3x3 Conv2D -> BatchNorm -> ReLU.
        """
        decoder = layers.Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same')(input_tensor)
        decoder = layers.concatenate([concat_tensor, decoder], axis=-1)
        decoder = layers.BatchNormalization()(decoder)
        decoder = layers.Activation('relu')(decoder)
        decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
        decoder = layers.BatchNormalization()(decoder)
        decoder = layers.Activation('relu')(decoder)
        decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
        decoder = layers.BatchNormalization()(decoder)
        decoder = layers.Activation('relu')(decoder)
        return decoder
我用于预测的类:
class predictor():
    """Run a trained single-class segmentation model over an image.

    The predicted mask is kept as float32. The model's sigmoid outputs are
    therefore preserved as probabilities in [0, 1] instead of being truncated
    to 0/1. Truncation would happen if the output buffer inherited an integer
    dtype from the image slices, which the original
    ``np.zeros_like(img_slices)`` did.
    """

    def __init__(self, model_path):
        """Load the Keras model at ``model_path``.

        The custom loss and metric functions the model was compiled with are
        registered through ``custom_objects``.
        """
        self.model = models.load_model(model_path, custom_objects={'wce_dice_loss': wce_dice_loss,
                                                                   'dice_loss': dice_loss,
                                                                   'IoU': IoU})

    def predict(self, image_path, batch_size):
        """Segment the image at ``image_path`` and store the result.

        Args:
            image_path: Path readable by ``plt.imread``.
            batch_size: Number of slices sent to ``model.predict`` per call;
                must be >= 1.

        Side effects:
            Sets ``self.img_arr`` (the single-channel input image) and
            ``self.mask_arr`` (the value returned by ``lip.mask_gluer``).
            The mask holds the model's raw sigmoid outputs. Threshold it
            explicitly (e.g. ``mask > 0.5``) if a binary mask is needed.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        img = plt.imread(image_path)
        # Take the first channel of RGB(A) images; grayscale images are already 2-D.
        self.img_arr = img[:, :, 0] if img.ndim == 3 else img

        img_slicer = lip(self.img_arr)
        img_slices = img_slicer.streched_img_map

        # Explicit float buffer. zeros_like() alone would copy the slices'
        # dtype, and an integer dtype truncates probabilities to 0/1.
        msk_slices = np.zeros_like(img_slices, dtype=np.float32)
        for start in range(0, len(img_slices), batch_size):
            stop = start + batch_size
            msk_slices[start:stop] = self.model.predict(img_slices[start:stop])

        # NOTE(review): if mask_gluer casts to an integer dtype, the
        # probabilities are lost there as well. Confirm in lip.mask_gluer.
        self.mask_arr = img_slicer.mask_gluer(msk_slices)

    def save(self, save_path):
        """Write the mask to ``save_path`` as a grayscale image.

        ``vmin``/``vmax`` are pinned to [0, 1]. Gray levels then map directly
        to probability instead of being stretched to the mask's own min/max.
        """
        # ndarray has no imsave method; use matplotlib's function instead.
        plt.imsave(save_path, self.mask_arr, cmap='gray', vmin=0, vmax=1)

    def show(self):
        """Display the mask with a fixed [0, 1] grayscale range."""
        plt.imshow(self.mask_arr, cmap='gray', vmin=0, vmax=1)
深度为 2 的模型的 model.summary() 输出:
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Input (InputLayer) [(None, 512, 512, 1) 0
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 512, 512, 32) 320 Input[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512, 512, 32) 128 conv2d[0][0]
__________________________________________________________________________________________________
activation (Activation) (None, 512, 512, 32) 0 batch_normalization[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 512, 512, 32) 9248 activation[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 512, 512, 32) 128 conv2d_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 512, 512, 32) 0 batch_normalization_1[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 256, 256, 32) 0 activation_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 256, 256, 64) 18496 max_pooling2d[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 256, 256, 64) 256 conv2d_2[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 256, 256, 64) 0 batch_normalization_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 256, 256, 64) 36928 activation_2[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 256, 256, 64) 256 conv2d_3[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 256, 256, 64) 0 batch_normalization_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 128, 128, 64) 0 activation_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 128, 128, 128 73856 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 128, 128, 128 512 conv2d_4[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 128, 128, 128 0 batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 128, 128, 128 147584 activation_4[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 128, 128, 128 512 conv2d_5[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 128, 128, 128 0 batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 256, 256, 64) 32832 activation_5[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 256, 256, 128 0 activation_3[0][0]
conv2d_transpose[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 256, 256, 128 512 concatenate[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 256, 256, 128 0 batch_normalization_6[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 256, 256, 64) 73792 activation_6[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 256, 256, 64) 256 conv2d_6[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 256, 256, 64) 0 batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 256, 256, 64) 36928 activation_7[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 256, 256, 64) 256 conv2d_7[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 256, 256, 64) 0 batch_normalization_8[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 512, 512, 32) 8224 activation_8[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 512, 512, 64) 0 activation_1[0][0]
conv2d_transpose_1[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 512, 512, 64) 256 concatenate_1[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 512, 512, 64) 0 batch_normalization_9[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 512, 512, 32) 18464 activation_9[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 512, 512, 32) 128 conv2d_8[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 512, 512, 32) 0 batch_normalization_10[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 512, 512, 32) 9248 activation_10[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 512, 512, 32) 128 conv2d_9[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 512, 512, 32) 0 batch_normalization_11[0][0]
__________________________________________________________________________________________________
Output (Conv2D) (None, 512, 512, 1) 33 activation_11[0][0]
==================================================================================================
Total params: 469,281
Trainable params: 467,617
Non-trainable params: 1,664