我正在构建用于二进制图像图像分割的u网络。我正在使用Tensorflow的tf.nn API。我的输入图像的尺寸为(256,256,3),输出的二进制图像的尺寸为(256,256,1)。 U net模型的输出必须为(1,256,256,1),但输出形状结果为(7,256,256,3)。对于卷积内核,我使用Tensorflow的截断法线初始构造函数,每个数据类型为float32。我是否在代码中的某个地方创建了多个输出层
def get_filter(shape,na):
w =tf.get_variable(name=na,shape=shape,dtype='float32',initializer=tf.truncated_normal_initializer(dtype='float32'))
return w
def unet(inp):
#f1 = get_filter(shape=[3,3,3,16])
lay_16_1 = tf.nn.conv2d(inp,filter=get_filter(shape=[3,3,3,16],na='w_1'),strides=[1,1,1,1],padding='SAME',name='conv_16_1')
lay_16_2 = tf.nn.relu(lay_16_1,name='re_16_1')
lay_16_3 = tf.layers.batch_normalization(lay_16_2,axis=-1,name='bn_16')
lay_16_4 = tf.nn.conv2d(lay_16_3,filter=get_filter([3,3,16,16],na='w_2'),strides=[1,1,1,1],padding='SAME',name='conv_16_2')
lay_16_5 = tf.nn.relu(lay_16_4,name='re_16_2')
lay_p1 = tf.nn.max_pool(lay_16_5,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME',name='pool_1')
lay_32_1 = tf.nn.conv2d(lay_p1,filter=get_filter([3,3,16,32],na='w_3'),strides=[1,1,1,1],padding='SAME',name='conv_32_1')
lay_32_2 = tf.nn.relu(lay_32_1,name='re_32_1')
lay_32_3 = tf.layers.batch_normalization(lay_32_2,axis=-1,name='bn_32')
lay_32_4 = tf.nn.conv2d(lay_32_3,filter=get_filter([3,3,32,32],na='w_4'),strides=[1,1,1,1],padding='SAME',name='conv_32_2')
lay_32_5 = tf.nn.relu(lay_32_4,name='re_32_2')
lay_p2 = tf.nn.max_pool(lay_32_5,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME',name='pool_2')
lay_64_1 = tf.nn.conv2d(lay_p2,filter=get_filter([3,3,32,64],na='w_5'),strides=[1,1,1,1],padding='SAME',name='conv_64_1')
lay_64_2 = tf.nn.relu(lay_64_1,name='re_64_1')
lay_64_3 = tf.layers.batch_normalization(lay_64_2,axis=-1,name='bn_64')
lay_64_4 = tf.nn.conv2d(lay_64_3,filter=get_filter([3,3,64,64],na='w_6'),strides=[1,1,1,1],padding='SAME',name='conv_64_2')
lay_64_5 = tf.nn.relu(lay_64_4,name='re_64_2')
lay_p3 = tf.nn.max_pool(lay_64_5,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME',name='pool_3')
lay_128_1 = tf.nn.conv2d(lay_p3,filter=get_filter([3,3,64,128],na='w_7'),strides=[1,1,1,1],padding='SAME',name='conv_128_1')
lay_128_2 = tf.nn.relu(lay_128_1,name='re_128_1')
lay_128_3 = tf.layers.batch_normalization(lay_128_2,axis=-1,name='bn_128')
lay_128_4 = tf.nn.conv2d(lay_128_3,filter=get_filter([3,3,128,128],na='w_8'),strides=[1,1,1,1],padding='SAME',name='conv_128_2')
lay_128_5 = tf.nn.relu(lay_128_4,name='re_128_2')
lay_p4 = tf.nn.max_pool(lay_128_5,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME',name='pool_4')
lay_256_1 = tf.nn.conv2d(lay_p4,filter=get_filter([3,3,128,256],na='w_9'),strides=[1,1,1,1],padding='SAME',name='conv_256_1')
lay_256_2 = tf.nn.relu(lay_256_1,name='re_256_1')
lay_256_3 = tf.layers.batch_normalization(lay_256_2,axis=-1,name='bn_256')
lay_256_4 = tf.nn.conv2d(lay_256_3,filter=get_filter([3,3,256,256],na='w_10'),strides=[1,1,1,1],padding='SAME',name='conv_256_2')
lay_256_5 = tf.nn.relu(lay_256_4,name='re_256_2')
lay_p5 = tf.nn.max_pool(lay_256_5,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME',name='pool_5')
lay_512_1 = tf.nn.conv2d(lay_p5,filter=get_filter([3,3,256,512],na='w_11'),strides=[1,1,1,1],padding='SAME',name='conv_512_1')
lay_512_2 = tf.nn.relu(lay_512_1,name='re_512_1')
lay_512_3 = tf.layers.batch_normalization(lay_512_2,axis=-1,name='bn_512')
lay_512_4 = tf.nn.conv2d(lay_512_3,filter=get_filter([3,3,512,512],na='w_12'),strides=[1,1,1,1],padding='SAME',name='conv_512_2')
lay_512_5 = tf.nn.relu(lay_512_4,name='re_512_2')
lay_p6 = tf.nn.max_pool(lay_512_5,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME',name='pool_6')
lay_1024_1 = tf.nn.conv2d(lay_p6,filter=get_filter([3,3,512,1024],na='w_13'),strides=[1,1,1,1],padding='SAME',name='conv_1024_1')
lay_1024_2 = tf.nn.relu(lay_1024_1,name='re_1024_1')
lay_1024_3 = tf.layers.batch_normalization(lay_1024_2,axis=-1,name='bn_1024')
lay_1024_4 = tf.nn.conv2d(lay_1024_3,filter=get_filter([3,3,1024,1024],na='w_14'),strides=[1,1,1,1],padding='SAME',name='conv_1024_2')
lay_1024_5 = tf.nn.relu(lay_1024_4,name='re_1024_2')
#lay_p7 = tf.nn.max_pool(lay_1024,ksize=[1,2,2,1],strides=[1,1,1,1],padding='SAME',name='pool_7')
up_512 = tf.image.resize_images(images=lay_1024_5,size=[8,8],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
con_512_1 = tf.nn.conv2d(up_512,filter=get_filter([3,3,1024,512],na='w_15'),strides=[1,1,1,1],padding='SAME',name='mer_512_1')
con_512_2 = tf.nn.relu(con_512_1,name='rel_512_1')
mer_512 = tf.concat([lay_512_5,con_512_2],axis=0,name='mer_512_2')
con_512_3 = tf.nn.conv2d(mer_512,filter=get_filter([3,3,512,512],na='w_16'),strides=[1,1,1,1],padding='SAME',name='mer_512_3')
con_512_4 = tf.nn.relu(con_512_3,name='rel_512_2')
con_512_5 = tf.layers.batch_normalization(con_512_4,axis=-1,name='mer_bn_512')
con_512_6 = tf.nn.conv2d(con_512_5,filter=get_filter([3,3,512,512],na='w_17'),strides=[1,1,1,1],padding='SAME',name='mer_512_4')
con_512_7 = tf.nn.relu(con_512_6,name='rel_512_3')
up_256 = tf.image.resize_images(images=con_512_7,size=[16,16],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
con_256_1 = tf.nn.conv2d(up_256,filter=get_filter([3,3,512,256],na='w_18'),strides=[1,1,1,1],padding='SAME',name='mer_256_1')
con_256_2 = tf.nn.relu(con_256_1,name='rel_256_1')
mer_256 = tf.concat([lay_256_5,con_256_2],axis=0,name='mer_256_2')
con_256_3 = tf.nn.conv2d(mer_256,filter=get_filter([3,3,256,256],na='w_19'),strides=[1,1,1,1],padding='SAME',name='mer_256_3')
con_256_4 = tf.nn.relu(con_256_3,name='rel_256_2')
con_256_5 = tf.layers.batch_normalization(con_256_4,axis=-1,name='mer_bn_256')
con_256_6 = tf.nn.conv2d(con_256_5,filter=get_filter([3,3,256,256],na='w_20'),strides=[1,1,1,1],padding='SAME',name='mer_256_4')
con_256_7 = tf.nn.relu(con_256_6,name='rel_256_3')
up_128 = tf.image.resize_images(images=con_256_7,size=[32,32],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
con_128_1 = tf.nn.conv2d(up_128,filter=get_filter([3,3,256,128],na='w_21'),strides=[1,1,1,1],padding='SAME',name='mer_128_1')
con_128_2 = tf.nn.relu(con_128_1,name='rel_128_1')
mer_128 = tf.concat([lay_128_5,con_128_2],axis=0,name='mer_128_2')
con_128_3 = tf.nn.conv2d(mer_128,filter=get_filter([3,3,128,128],na='w_22'),strides=[1,1,1,1],padding='SAME',name='mer_128_3')
con_128_4 = tf.nn.relu(con_128_3,name='rel_128_2')
con_128_5 = tf.layers.batch_normalization(con_128_4,axis=-1,name='mer_bn_128')
con_128_6 = tf.nn.conv2d(con_128_5,filter=get_filter([3,3,128,128],na='w_23'),strides=[1,1,1,1],padding='SAME',name='mer_128_4')
con_128_7 = tf.nn.relu(con_128_6,name='rel_128_3')
up_64 = tf.image.resize_images(images=con_128_7,size=[64,64],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
con_64_1 = tf.nn.conv2d(up_64,filter=get_filter([3,3,128,64],na='w_24'),strides=[1,1,1,1],padding='SAME',name='mer_64_1')
con_64_2 = tf.nn.relu(con_64_1,name='rel_64_1')
mer_64 = tf.concat([lay_64_5,con_64_2],axis=0,name='mer_64_2')
con_64_3 = tf.nn.conv2d(mer_64,filter=get_filter([3,3,64,64],na='w_25'),strides=[1,1,1,1],padding='SAME',name='mer_64_3')
con_64_4 = tf.nn.relu(con_64_3,name='rel_64_2')
con_64_5 = tf.layers.batch_normalization(con_64_4,axis=-1,name='mer_bn_64')
con_64_6 = tf.nn.conv2d(con_64_5,filter=get_filter([3,3,64,64],na='w_26'),strides=[1,1,1,1],padding='SAME',name='mer_64_4')
con_64_7 = tf.nn.relu(con_64_6,name='rel_64_3')
up_32 = tf.image.resize_images(images=con_64_7,size=[128,128],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
con_32_1 = tf.nn.conv2d(up_32,filter=get_filter([3,3,64,32],na='w_27'),strides=[1,1,1,1],padding='SAME',name='mer_32_1')
con_32_2 = tf.nn.relu(con_32_1,name='rel_32_1')
mer_32 = tf.concat([lay_32_5,con_32_2],axis=0,name='mer_32_2')
con_32_3 = tf.nn.conv2d(mer_32,filter=get_filter([3,3,32,32],na='w_28'),strides=[1,1,1,1],padding='SAME',name='mer_32_3')
con_32_4 = tf.nn.relu(con_32_3,name='rel_32_2')
con_32_5 = tf.layers.batch_normalization(con_32_4,axis=-1,name='mer_bn_32')
con_32_6 = tf.nn.conv2d(con_32_5,filter=get_filter([3,3,32,32],na='w_29'),strides=[1,1,1,1],padding='SAME',name='mer_32_4')
con_32_7 = tf.nn.relu(con_32_6,name='rel_32_3')
up_16 = tf.image.resize_images(images=con_32_7,size=[256,256],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
con_16_1 = tf.nn.conv2d(up_16,filter=get_filter([3,3,32,16],na='w_30'),strides=[1,1,1,1],padding='SAME',name='mer_16_1')
con_16_2 = tf.nn.relu(con_16_1,name='rel_16_1')
mer_16 = tf.concat([lay_16_5,con_16_2],axis=0,name='mer_16_2')
con_16_3 = tf.nn.conv2d(mer_16,filter=get_filter([3,3,16,16],na='w_31'),strides=[1,1,1,1],padding='SAME',name='mer_16_3')
con_16_4 = tf.nn.relu(con_16_3,name='rel_16_2')
con_16_5 = tf.layers.batch_normalization(con_16_4,axis=-1,name='mer_bn_16')
con_16_6 = tf.nn.conv2d(con_16_5,filter=get_filter([3,3,16,16],na='w_32'),strides=[1,1,1,1],padding='SAME',name='mer_16_4')
con_16_7 = tf.nn.relu(con_16_6,name='rel_16_3')
fin_img = tf.nn.conv2d(con_16_7,filter=get_filter([1,1,16,1],na='w_33'),strides=[1,1,1,1],padding='SAME',name='final_image')
#fin_img = tf.nn.sigmoid(fin_img)
return fin_img