如何使用具有两个None维的Reshape keras层?

时间:2019-03-04 12:28:28

标签: python tensorflow keras deep-learning

我有一个keras 3D / 2D模型。在此模型中,3D层的形状为[None,None,4,32]。我想将其重塑为[None,None,128]。但是,如果我只是执行以下操作:

  

reshaped_layer = Reshape((-1,128))(my_layer)

my_layer的形状为[None,128],因此以后我无法再应用任何2D卷积,例如:

  

conv_x = Conv2D(16,(1,1))(reshaped_layer)

我尝试使用tf.shape(my_layer)和tf.reshape,但是由于tf.reshape不是Keras层,所以我无法编译模型。

请澄清一下,我最后使用的是频道;这不是tf.keras,这只是Keras。在这里,我发送了重塑功能的调试:keras中的重塑

这是我现在正在按照anna-krogager的建议进行的操作:

def reshape(x):
    x_shape = K.shape(x)
    new_x_shape = K.concatenate([x_shape[:-2], [x_shape[-2] * x_shape[-1]]])
    return K.reshape(x, new_x_shape)

reshaped = Lambda(lambda x: reshape(x))(x)
reshaped.set_shape([None,None, None, 128])
conv_x = Conv2D(16, (1,1))(reshaped)

我收到以下错误:ValueError:应该定义输入的通道尺寸。没有找到

Debug, keras reshape

3 个答案:

答案 0 :(得分:1)

您可以使用K.shape来获取输入的形状(作为张量),并将重塑包装在Lambda层中,如下所示:

def reshape(x):
    x_shape = K.shape(x)
    new_x_shape = K.concatenate([x_shape[:-2], [x_shape[-2] * x_shape[-1]]])
    return K.reshape(x, new_x_shape)

reshaped = Lambda(lambda x: reshape(x))(x)
reshaped.set_shape([None, None, None, a * b]) # when x is of shape (None, None, a, b)

这会将形状为(None, None, a, b)的张量重塑为(None, None, a * b)

答案 1 :(得分:0)

深入到base_layer.py中,我发现改形为:

  

tf.Tensor'lambda_1 / Reshape:0'shape =(?,?,?,128)dtype = float32。

但是,即使在set_shape之后,其属性“ _keras_shape”也是(无,无,无,无)。因此,解决方案是设置以下属性:

def reshape(x):
    x_shape = K.shape(x)
    new_x_shape = K.concatenate([x_shape[:-2], [x_shape[-2] * x_shape[-1]]])
    return K.reshape(x, new_x_shape)

reshaped = Lambda(lambda x: reshape(x))(x)
reshaped.set_shape([None, None, None, 128])
reshaped.__setattr__("_keras_shape", (None, None, None, 128))
conv_x = Conv2D(16, (1,1))(reshaped)

答案 2 :(得分:0)

由于您正在重塑(4,32)而不损失尺寸,因此最好的是(128,1)或(1,128)。因此,您可以执行以下操作:

# original has shape [None, None, None, 4, 32] (including batch)

reshaped_layer = Reshape((-1, 128))(original) # shape is [None, None, 128]
conv_layer = Conv2D(16, (1,1))(K.expand_dims(reshaped_layer, axis=-2)) # shape is [None, None, 1, 16]