基于输入的裁剪keras层

时间:2018-03-13 13:08:51

标签: python-3.x tensorflow neural-network deep-learning keras

我正在尝试使用V-net实施Keras。 这个过程非常简单(感谢Keras),但我遇到了一些问题。 Vnet是FCNN(完全卷积神经网络),因此理论上输入的大小可以变化,只涉及卷积。但是,在收缩/缩小/卷积路径和解除/上/解卷积路径之间跳过层(长连接)的需要使我有义务选择固定的输入大小。

然而,我尝试根据2个连接图层之间的差异动态地实现Lambda图层裁剪。

这是我的实现假设带有“channels_last”数据格式的3d输入。

def CropToConcat3D():
"""
   Layer cropping bigger to smaller with concatenation.
"""

    def crop_to_3D(inputs):
    """
        inputs = bigger_input,smaller_input
        Crop bigger input to smaller input to
        have same dimension.
    """
        bigger_input,smaller_input = inputs
        bigger_input_size=bigger.get_shape().as_list()
        smaller_input_size=smaller.get_shape().as_list()
        _,bh,bw,bd,_ = bigger_input_size
        _,sh,sw,sd,_ = smaller_input_size
        if (bh is None) and (bw is None) and (bd is None):
            cropped_to_smaller_input = bigger_input
        else:
            cropped_to_smaller_input = bigger_input
            dh,dw,dd = bh-sh, bw-sw, bd-sd
            q2dh,r2dh,q2dw,r2dw,q2dd,r2dd = dh//2,dh%2,dw//2,dw%2,dd//2,dd%2
            cropped_to_smaller_input = bigger_input[:,q2dh:bh-(q2dh+r2dh),
                                                      q2dw:bw-(q2dw+r2dw),
                                                      q2dd:bd-(q2dd+r2dd),:]
        return K.concatenate([smaller_input,cropped_to_smaller_input],
                           axis=-1)

    return Lambda(crop_to_3D)     

如果没有必要在编译模型时处理变量大小,则在使用fit,predict或其他方法运行模型时会忽略else分支。

图层使用如下,并且应该从张量动态获取尺寸(我使用张量流后端_keras_shape应该与其他后端一起使用)。

outputs = CropToConcat3D()([bigger,smaller])

所以代码似乎正在运行但是else分支总是被忽略,是否有一组编译的属性或Keras中的层指令我没有考虑到?我还检查了声明一个固定的输入,在这种情况下,访问了else分支,但如果我给出一个不同的形状输入,输入层会抛出一个错误,因为输入不符合模型尺寸。

我已经尝试了在何时/何地获取/存储维度的代码变体,但行为始终是相同的。

谢谢大家的见解。

解:d

问题在于,否则计算图中没有考虑其他语句。使用tf.cond解决了这个问题。 对于每个前馈操作,该网必须运行一次这个指令只有一次(语句是编译的,我认为是其中一个分支) 它非常难看但有效。

def CropToConcat3D():
"""
    inputs = bigger_input,smaller_input
    Crop bigger input to smaller input to
    have same dimension.
"""

def control_copy_crop3D(inputs):
    bigger_input,smaller_input = inputs
    def simple_concat_3D():
        return K.concatenate([bigger_input,smaller_input], axis=-1)
    def crop_to_concat_3D():
        bigger_shape, smaller_shape = tf.shape(bigger_input), \
                                      tf.shape(smaller_input)
        sh,sw,sd = smaller_shape[1],smaller_shape[2],smaller_shape[3]
        bh,bw,bd = bigger_shape[1],bigger_shape[2],bigger_shape[3]
        dh,dw,dd = bh-sh, bw-sw, bd-sd
        q2dh,r2dh,q2dw,r2dw,q2dd,r2dd = dh//2,dh%2,dw//2,dw%2,dd//2,dd%2
        cropped_to_smaller_input = bigger_input[:,q2dh:bh-(q2dh+r2dh),
                                                  q2dw:bw-(q2dw+r2dw),
                                                  q2dd:bd-(q2dd+r2dd),:]
        return K.concatenate([smaller_input,cropped_to_smaller_input], axis=-1)

    smaller_shape = tf.shape(smaller_input)
    sh = smaller_shape[1]
    return tf.cond(tf.Variable(sh is None,dtype=tf.bool),simple_concat_3D,
                   crop_to_concat_3D)

return Lambda(control_copy_crop3D)      

0 个答案:

没有答案