我正在尝试使用V-net实施Keras。 这个过程非常简单(感谢Keras),但我遇到了一些问题。 Vnet是FCNN(完全卷积神经网络),因此理论上输入的大小可以变化,只涉及卷积。但是,在收缩/缩小/卷积路径和解除/上/解卷积路径之间跳过层(长连接)的需要使我有义务选择固定的输入大小。
然而,我尝试根据2个连接图层之间的差异动态地实现Lambda图层裁剪。
这是我的实现假设带有“channels_last”数据格式的3d输入。
def CropToConcat3D():
"""
Layer cropping bigger to smaller with concatenation.
"""
def crop_to_3D(inputs):
"""
inputs = bigger_input,smaller_input
Crop bigger input to smaller input to
have same dimension.
"""
bigger_input,smaller_input = inputs
bigger_input_size=bigger.get_shape().as_list()
smaller_input_size=smaller.get_shape().as_list()
_,bh,bw,bd,_ = bigger_input_size
_,sh,sw,sd,_ = smaller_input_size
if (bh is None) and (bw is None) and (bd is None):
cropped_to_smaller_input = bigger_input
else:
cropped_to_smaller_input = bigger_input
dh,dw,dd = bh-sh, bw-sw, bd-sd
q2dh,r2dh,q2dw,r2dw,q2dd,r2dd = dh//2,dh%2,dw//2,dw%2,dd//2,dd%2
cropped_to_smaller_input = bigger_input[:,q2dh:bh-(q2dh+r2dh),
q2dw:bw-(q2dw+r2dw),
q2dd:bd-(q2dd+r2dd),:]
return K.concatenate([smaller_input,cropped_to_smaller_input],
axis=-1)
return Lambda(crop_to_3D)
如果没有必要在编译模型时处理变量大小,则在使用fit,predict或其他方法运行模型时会忽略else分支。
图层使用如下,并且应该从张量动态获取尺寸(我使用张量流后端_keras_shape应该与其他后端一起使用)。
outputs = CropToConcat3D()([bigger,smaller])
所以代码似乎正在运行但是else分支总是被忽略,是否有一组编译的属性或Keras中的层指令我没有考虑到?我还检查了声明一个固定的输入,在这种情况下,访问了else分支,但如果我给出一个不同的形状输入,输入层会抛出一个错误,因为输入不符合模型尺寸。
我已经尝试了在何时/何地获取/存储维度的代码变体,但行为始终是相同的。
谢谢大家的见解。
解:d
问题在于,否则计算图中没有考虑其他语句。使用tf.cond解决了这个问题。 对于每个前馈操作,该网必须运行一次这个指令只有一次(语句是编译的,我认为是其中一个分支) 它非常难看但有效。
def CropToConcat3D():
"""
inputs = bigger_input,smaller_input
Crop bigger input to smaller input to
have same dimension.
"""
def control_copy_crop3D(inputs):
bigger_input,smaller_input = inputs
def simple_concat_3D():
return K.concatenate([bigger_input,smaller_input], axis=-1)
def crop_to_concat_3D():
bigger_shape, smaller_shape = tf.shape(bigger_input), \
tf.shape(smaller_input)
sh,sw,sd = smaller_shape[1],smaller_shape[2],smaller_shape[3]
bh,bw,bd = bigger_shape[1],bigger_shape[2],bigger_shape[3]
dh,dw,dd = bh-sh, bw-sw, bd-sd
q2dh,r2dh,q2dw,r2dw,q2dd,r2dd = dh//2,dh%2,dw//2,dw%2,dd//2,dd%2
cropped_to_smaller_input = bigger_input[:,q2dh:bh-(q2dh+r2dh),
q2dw:bw-(q2dw+r2dw),
q2dd:bd-(q2dd+r2dd),:]
return K.concatenate([smaller_input,cropped_to_smaller_input], axis=-1)
smaller_shape = tf.shape(smaller_input)
sh = smaller_shape[1]
return tf.cond(tf.Variable(sh is None,dtype=tf.bool),simple_concat_3D,
crop_to_concat_3D)
return Lambda(control_copy_crop3D)