How can I make the model return a probability map instead of the arg_max?

Asked: 2019-09-11 10:01:08

Tags: python tensorflow keras

I have written a UNet class that lets me define a U-Net of variable depth. I am training it to segment a single class from images, but the problem is that when I use the trained model to run inference on an image, the returned array contains only ones and zeros, whereas I would like to see the actual per-pixel probabilities for the class. How can I achieve this?
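For reference, this is a minimal sketch of the kind of check I would like to do on the raw output of model.predict ('unet.h5' and the random batch are only placeholders for my real model file and data):

import numpy as np
from tensorflow.keras import models

model = models.load_model('unet.h5', compile=False)               # placeholder path to the trained model
dummy_batch = np.random.rand(2, 512, 512, 1).astype(np.float32)   # placeholder input batch
raw = model.predict(dummy_batch)
print(raw.dtype, raw.min(), raw.max())                            # sigmoid output should be float values in [0, 1]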

The class I use to create the model:

# Imports assumed for this snippet:
from tensorflow.keras import layers, models


class UNet():
    def __init__(self, depth, shape):
        encoder_layers = [[0, 0] for x in range(depth)]
        decoder_layers = [0 for x in range(depth)]
        inputs = layers.Input(shape=shape, name='Input')                                        # Input layer
        encoder_layers[0] = self.encoder_block(inputs, 32)                                      # First encoder block
        for i in range(1, depth):                                                               # Assemble the remaining encoder blocks
            encoder_layers[i] = self.encoder_block(encoder_layers[i-1][0], 32*2**i)
        center = self.conv_block(encoder_layers[-1][0], 32*2**depth)                            # Center
        decoder_layers[0] = self.decoder_block(center, encoder_layers[-1][1], 32*2**(depth-1))  # First decoder block
        for i in range(1, depth):                                                               # Assemble the remaining decoder blocks
            decoder_layers[i] = self.decoder_block(decoder_layers[i-1], encoder_layers[-i-1][1], 32*2**(i-1))
        outputs = layers.Conv2D(1, (1, 1), activation='sigmoid', name='Output')(decoder_layers[-1])  # Output layer (per-pixel sigmoid)
        self.model = models.Model(inputs=[inputs], outputs=[outputs])

    def encoder_block(self, input_tensor, num_filters):
        encoder = self.conv_block(input_tensor, num_filters)
        encoder_pool = layers.MaxPooling2D((2, 2), strides=(2, 2))(encoder)
        return encoder_pool, encoder

    @staticmethod
    def conv_block(input_tensor, num_filters):
        encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(input_tensor)
        encoder = layers.BatchNormalization()(encoder)
        encoder = layers.Activation('relu')(encoder)
        encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(encoder)
        encoder = layers.BatchNormalization()(encoder)
        encoder = layers.Activation('relu')(encoder)
        return encoder

    @staticmethod
    def decoder_block(input_tensor, concat_tensor, num_filters):
        decoder = layers.Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same')(input_tensor)
        decoder = layers.concatenate([concat_tensor, decoder], axis=-1)
        decoder = layers.BatchNormalization()(decoder)
        decoder = layers.Activation('relu')(decoder)
        decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
        decoder = layers.BatchNormalization()(decoder)
        decoder = layers.Activation('relu')(decoder)
        decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
        decoder = layers.BatchNormalization()(decoder)
        decoder = layers.Activation('relu')(decoder)
        return decoder
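This is roughly how I instantiate the class for the summary shown further down; the optimizer and loss below are only stand-ins for my actual training setup, which uses wce_dice_loss and IoU:

unet = UNet(depth=2, shape=(512, 512, 1))
unet.model.compile(optimizer='adam', loss='binary_crossentropy')   # stand-in compile settings
unet.model.summary()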

The class I use for prediction:

# Imports assumed for this snippet; `lip`, `wce_dice_loss`, `dice_loss` and `IoU`
# are custom helpers defined elsewhere in my project.
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import models


class predictor():
    def __init__(self, model_path):
        self.model = models.load_model(model_path, custom_objects={'wce_dice_loss': wce_dice_loss,
                                                                   'dice_loss': dice_loss,
                                                                   'IoU': IoU})

    def predict(self, image_path, batch_size):
        self.img_arr = plt.imread(image_path)[:, :, 0]
        img_slicer = lip(self.img_arr)                      # custom class that slices the image into model-sized patches
        img_slices = img_slicer.streched_img_map
        msk_slices = np.zeros_like(img_slices)
        index = 0
        while True:
            msk_slices[index:index + batch_size] = self.model.predict(img_slices[index:index + batch_size])
            index += batch_size
            if index >= len(img_slices):
                break
        self.mask_arr = img_slicer.mask_gluer(msk_slices)   # custom method that stitches the patches back together

    def save(self, save_path):
        plt.imsave(save_path, self.mask_arr, cmap='gray')

    def show(self):
        plt.imshow(self.mask_arr, cmap='gray')
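A rough usage sketch of the predictor class (the file paths and batch size are placeholders):

p = predictor('trained_unet.h5')        # placeholder path to the saved model
p.predict('image.png', batch_size=4)    # placeholder image path and batch size
p.show()
p.save('predicted_mask.png')            # placeholder output path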

model.summary() for a model of depth 2:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
Input (InputLayer)              [(None, 512, 512, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 512, 512, 32) 320         Input[0][0]                      
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512, 512, 32) 128         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 512, 512, 32) 0           batch_normalization[0][0]        
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 512, 512, 32) 9248        activation[0][0]                 
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 512, 512, 32) 128         conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 512, 512, 32) 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 256, 256, 32) 0           activation_1[0][0]               
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 256, 256, 64) 18496       max_pooling2d[0][0]              
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 256, 256, 64) 256         conv2d_2[0][0]                   
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 256, 256, 64) 0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 256, 256, 64) 36928       activation_2[0][0]               
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 256, 256, 64) 256         conv2d_3[0][0]                   
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 256, 256, 64) 0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 128, 128, 64) 0           activation_3[0][0]               
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 128, 128, 128 73856       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 128, 128, 128 512         conv2d_4[0][0]                   
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 128, 128, 128 0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 128, 128, 128 147584      activation_4[0][0]               
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 128, 128, 128 512         conv2d_5[0][0]                   
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 128, 128, 128 0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 256, 256, 64) 32832       activation_5[0][0]               
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 256, 256, 128 0           activation_3[0][0]               
                                                                 conv2d_transpose[0][0]           
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 256, 256, 128 512         concatenate[0][0]                
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 256, 256, 128 0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 256, 256, 64) 73792       activation_6[0][0]               
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 256, 256, 64) 256         conv2d_6[0][0]                   
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 256, 256, 64) 0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 256, 256, 64) 36928       activation_7[0][0]               
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 256, 256, 64) 256         conv2d_7[0][0]                   
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 256, 256, 64) 0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 512, 512, 32) 8224        activation_8[0][0]               
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 512, 512, 64) 0           activation_1[0][0]               
                                                                 conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 512, 512, 64) 256         concatenate_1[0][0]              
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 512, 512, 64) 0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 512, 512, 32) 18464       activation_9[0][0]               
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 512, 512, 32) 128         conv2d_8[0][0]                   
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 512, 512, 32) 0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 512, 512, 32) 9248        activation_10[0][0]              
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 512, 512, 32) 128         conv2d_9[0][0]                   
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 512, 512, 32) 0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
Output (Conv2D)                 (None, 512, 512, 1)  33          activation_11[0][0]              
==================================================================================================
Total params: 469,281
Trainable params: 467,617
Non-trainable params: 1,664
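Since the Output layer is a 1x1 Conv2D with a sigmoid activation, the network itself should already emit per-pixel probabilities; a quick sanity check on the loaded model (here model stands for the loaded Keras model) would be:

out_layer = model.get_layer('Output')   # 'Output' is the layer name used in the class above
print(out_layer.activation)             # expected to be the sigmoid activation function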

0 Answers:

No answers yet.