渠道维度找到'无',但是通过tensor.get_shape()已定义

时间:2018-05-28 06:38:03

标签: python tensorflow keras tensor

当我试图将一些相同形状的张量连接在一起以通过某些操作(一些完全连接的层等)处理时,我遇到了问题 我将它们连接起来如下:

x_con = KL.concatenate([x1, x2, x3, x4], axis=-1,name='P3_concatenate')
x=Squeeze_excitation(x_con, origin_x=x, out_dim=256 * 4, ratio=16, layer_name='x_con')

和Squeeze_exicitation是我用张量定义处理的函数:

def Squeeze_excitation(input_x, origin_x, out_dim, ratio, layer_name):
    print("input shape:",input_x.get_shape().as_list())
    squeeze = KL.GlobalAveragePooling2D(name=layer_name+'_Squeeze_Layer')(input_x)

    excitation = KL.Dense(units=out_dim // ratio, name=layer_name + '_fully_connected1')(squeeze)

    excitation = KL.Activation('relu',name=layer_name+'_relu')(excitation)
    excitation = KL.Dense(units=out_dim, name=layer_name + '_fully_connected2')(excitation)

    excitation = KL.Activation('sigmoid', name=layer_name+'_sigmoid')(excitation)

    excitation = KL.Reshape((1, 1, out_dim))(excitation)
    print("exicitation shape:", excitation.get_shape().as_list())

    scale = KL.multiply([input_x,excitation], name=layer_name+'_multiply')
    print("scale shape:",scale.get_shape().as_list())

    index = K.constant(value=out_dim//4,dtype=tf.int32)
    scale = KL.add([scale[:, :, :, 0:index], scale[:, :, :, index:2 * index],
                    scale[:, :, :, 2 * index:3 * index], scale[:, :, :, 3 * index:]],name=layer_name+'_Add_n')
    print("scale shape:", scale.get_shape().as_list())

    return scale

在张量处理之后,我尝试在这个张量中添加一个卷积层:

x = KL.Conv2D(256, (3, 3), padding="SAME")(x)

它引发了一个错误:

ValueError: The channel dimension of the inputs should be defined. Found `None`.

要使用它进行处理时检查张量的大小,我添加了一些输出以保持其尺寸,这里是函数Squueze_excitation中的输出:

input shape: [None, 32, 32, 1024]
exicitation shape: [None, 1, 1, 1024]
scale shape: [None, 32, 32, 1024]
scale shape: [None, 32, 32, 256]

我不知道如何处理这个问题,有人可以帮助我离开这里吗? P.S.如果在这里有任何不明确的问题描述,请随时提出。

1 个答案:

答案 0 :(得分:0)

我修复了这个问题,发现它是由变量索引引起的,这是一个keras.constant,我把它改成:

scale = KL.Lambda(lambda inputs: KL.Add()([inputs[:, :, :, 0:256], inputs[:, :, :, 256:512],
                    inputs[:, :, :, 512:768], inputs[:, :, :, 768:]]))(scale)

这意味着我用Keras中的Lambda Layer包装它。我在这里删除了变量索引。