传递形状为(2,256,256,120,2)的目标数组以输出形状为(None,256,256,592,1)的输出,同时使用“ binary_crossentropy”作为损失

时间:2019-11-12 04:18:51

标签: python tensorflow keras image-segmentation

我正在用3D体积分割医学图像。图像的原始大小为(256,256,120),在DataGeneration之后,输入为(2,256,256,120,1),而输出为(2,256,256,120,2)。但是在model generation_fit上,我得到了上面的消息,如标题所示:

型号:

def build_model(inp_shape=(256, 256, 120 ,1)):
print("start building NN")

#Testing building my model  
data = Input(inp_shape) 
print(data)
flt=32


conv1 = Conv3D(flt, (3, 3, 3), activation='relu', padding='same')(data)
conv1 = Conv3D(flt, (3, 3, 3), activation='relu', padding='same')(conv1)
pool1 = MaxPooling3D(pool_size=(2, 2, 2))(conv1)

conv2 = Conv3D(flt*2, (3, 3, 3), activation='relu', padding='same')(pool1)
conv2 = Conv3D(flt*2, (3, 3, 3), activation='relu', padding='same')(conv2)
pool2 = MaxPooling3D(pool_size=(2, 2, 2))(conv2)

conv3 = Conv3D(flt*4, (3, 3, 3), activation='relu', padding='same')(pool2)
conv3 = Conv3D(flt*4, (3, 3, 3), activation='relu', padding='same')(conv3)
pool3 = MaxPooling3D(pool_size=(2, 2, 2))(conv3)

conv4 = Conv3D(flt*8, (3, 3, 3), activation='relu', padding='same')(pool3)
conv4 = Conv3D(flt*8, (3, 3, 3), activation='relu', padding='same')(conv4)
pool4 = MaxPooling3D(pool_size=(2, 2, 2))(conv4)

conv5 = Conv3D(flt*16, (3, 3, 3), activation='relu', padding='same')(pool4)
conv5 = Conv3D(flt*16, (3, 3, 3), activation='relu', padding='same')(conv5)

up6 = concatenate([Conv3DTranspose(flt*8, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv5), conv4], axis=3)
conv6 = Conv3D(flt*8, (3, 3, 3), activation='relu', padding='same')(up6)
conv6 = Conv3D(flt*8, (3, 3, 3), activation='relu', padding='same')(conv6)

up7 = concatenate([Conv3DTranspose(flt*4, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv6), conv3], axis=3)
conv7 = Conv3D(flt*4, (3, 3, 3), activation='relu', padding='same')(up7)
conv7 = Conv3D(flt*4, (3, 3, 3), activation='relu', padding='same')(conv7)

up8 = concatenate([Conv3DTranspose(flt*2, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv7), conv2], axis=3)
conv8 = Conv3D(flt*2, (3, 3, 3), activation='relu', padding='same')(up8)
conv8 = Conv3D(flt*2, (3, 3, 3), activation='relu', padding='same')(conv8)

up9 = concatenate([Conv3DTranspose(flt, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv8), conv1], axis=3)
conv9 = Conv3D(flt, (3, 3, 3), activation='relu', padding='same')(up9)
conv9 = Conv3D(flt, (3, 3, 3), activation='relu', padding='same')(conv9)

conv10 = Conv3D(1, (1,1,1), activation='sigmoid')(conv9)

model = Model(inputs=[data], outputs=[conv10])

return model

火车:

if __name__ == '__main__':

im_path = "/root/data/demo/Demo/images/"
label_path = "/root/data/demo/Demo/labels/"

# Parameters
params = {'dim': (256,256,120),
         'batch_size': 2,
         'im_path': im_path,
         'label_path': label_path,
         'n_classes': 2,
         'augment': False,
         'n_channels':1,
         'shuffle': True
         }

partition = {}


images = glob.glob(os.path.join(im_path, "*.nii.gz"))
images_IDs = [name.split("/")[-1] for name in images]

partition['train'] = images_IDs

# Load training data
training_generator = DataGenerator(partition['train'], **params) 
 # Design model
model = Sequential()



# Build model

inp_shape = y[0].shape  
model = build_model() 
model.compile(optimizer=Adam(lr=1e-5), loss='binary_crossentropy', metrics=['binary_accuracy'])


##########################################################################################
checkpointer = ModelCheckpoint('model.{epoch:03d}.hdf5', save_freq=5)


cb_1=keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=2, verbose=1, mode='auto')
cb_2=keras.callbacks.ModelCheckpoint(filepath="model.{epoch:03d}.hdf5", monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', period=1)



results = model.fit_generator(training_generator, steps_per_epoch= 4,
                                      epochs=2,
                                      verbose=1, 
                                      use_multiprocessing=True,
                                      workers=4,
                                      validation_data=None, validation_steps=None,
                                      max_queue_size=8)

0 个答案:

没有答案