"ValueError: total size of new array must be unchanged" when using the Reshape layer in Keras

Time: 2020-08-17 09:32:19

Tags: python tensorflow keras

I am trying to follow this blog post, in which they build an autoencoder whose encoder and decoder share weights.

https://medium.com/@lmayrandprovencher/building-an-autoencoder-with-tied-weights-in-keras-c4a559c529a2

    class DenseTranspose(keras.layers.Layer):
        """Dense layer that reuses the transposed kernel of another Dense layer."""

        def __init__(self, dense, activation=None, **kwargs):
            self.dense = dense
            self.activation = keras.activations.get(activation)
            super().__init__(**kwargs)

        def build(self, batch_input_shape):
            # Only the bias is a new weight; the kernel is shared with self.dense.
            self.biases = self.add_weight(name='bias',
                                          shape=[self.dense.input_shape[-1]],
                                          initializer='zeros')
            super().build(batch_input_shape)

        def call(self, inputs):
            # Multiply by the transposed kernel of the tied Dense layer.
            z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
            return self.activation(z + self.biases)

    inputs = Input(shape = (28,28,1))

    dense_1 = Dense(784,activation='relu')
    dense_2 = Dense(392,activation='relu')
    dense_3 = Dense(196,activation='relu')

    # Flatten the input data so it can be fed into the Dense layers.

    x = Flatten()(inputs)
    x = dense_1(x)
    x = dense_2(x)
    x = dense_3(x)
    x = DenseTranspose(dense_3,activation='relu')(x)
    x = DenseTranspose(dense_2,activation='relu')(x)
    x = DenseTranspose(dense_1,activation='sigmoid')(x)

    outputs = Reshape([28,28])(x)

However, when I try to run the code, it fails with:

    Traceback (most recent call last):
      File "<string>", line 1, in <module>
      File ".\Programs\Python\Python37\lib\site-packages\keras\engine\base_layer.py", line 474, in __call__
        output_shape = self.compute_output_shape(input_shape)
      File ".\Programs\Python\Python37\lib\site-packages\keras\layers\core.py", line 398, in compute_output_shape
        input_shape[1:], self.target_shape)
      File ".\Programs\Python\Python37\lib\site-packages\keras\layers\core.py", line 386, in _fix_unknown_dimension
        raise ValueError(msg)
    ValueError: total size of new array must be unchanged

Printing the value of x just before the line outputs = Reshape([28,28])(x) gives the following:

    <tf.Tensor 'dense_transpose_3/Sigmoid:0' shape=(?, 784) dtype=float32>

Since 784 = 28 * 28, why can't I reshape it to (28, 28) as the blog post does?
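For reference, the kind of check behind this error message can be sketched in plain Python. This is a simplified model, not Keras's actual code, and the helper name `check_reshape` is hypothetical. It assumes the check compares the number of elements per sample (everything after the batch axis) with the product of the target shape, which is what the `input_shape[1:], self.target_shape` call in the traceback suggests.

```python
from functools import reduce
from operator import mul


def check_reshape(input_shape, target_shape):
    """Simplified model of a per-sample reshape check (hypothetical helper).

    input_shape includes the batch axis as its first entry (None if unknown);
    target_shape excludes it, like the argument given to a Reshape layer.
    """
    per_sample = reduce(mul, input_shape[1:], 1)
    wanted = reduce(mul, target_shape, 1)
    if per_sample != wanted:
        raise ValueError("total size of new array must be unchanged")
    return (input_shape[0],) + tuple(target_shape)


# 784 == 28 * 28, so this shape is accepted.
ok_shape = check_reshape((None, 784), (28, 28))

# 196 != 28 * 28, so the same target shape is rejected.
try:
    check_reshape((None, 196), (28, 28))
    mismatch_raises = False
except ValueError:
    mismatch_raises = True
```

Under this model the error would only appear if the shape Keras has inferred for x is not (None, 784), even though the TensorFlow tensor prints as (?, 784). As far as I know, in standalone Keras 2.x a custom layer that does not override `compute_output_shape` is assumed to return its input shape, so the inferred shape after the DenseTranspose layers may still be (None, 196). That would be one place to look.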

1 Answer:

Answer 0 (score: 0)

I don't fully understand the problem, but I suggest replacing the Reshape layer with a Lambda layer:

    # K.reshape takes the full shape, so -1 is used to keep the batch dimension.
    outputs = Lambda(lambda t: K.reshape(t, (-1, 28, 28)))(x)

See here.
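One caveat with the Lambda approach: a backend reshape works on the whole tensor including the batch axis, whereas the target shape of a Reshape layer excludes it. A fixed target of (28, 28) can therefore only hold a single sample, while a -1 entry lets the batch size be inferred. The plain-Python sketch below models that convention. The helper name `resolve_full_shape` and the batch size of 32 are made up for illustration, and this is not the backend's real implementation.

```python
def resolve_full_shape(num_elements, target_shape):
    """Resolve a single -1 entry in a full, batch-inclusive target shape.

    Hypothetical helper that mirrors the usual reshape convention.
    """
    known = 1
    for dim in target_shape:
        if dim != -1:
            known *= dim
    if -1 in target_shape:
        if target_shape.count(-1) > 1 or known == 0 or num_elements % known != 0:
            raise ValueError("cannot infer the -1 dimension")
        return tuple(num_elements // known if d == -1 else d for d in target_shape)
    if known != num_elements:
        raise ValueError("total size of new array must be unchanged")
    return tuple(target_shape)


batch_size = 32                      # assumed batch size, for illustration only
num_elements = batch_size * 784      # a batch of flattened 28x28 images

# With -1 the batch axis is inferred from the element count.
with_batch_axis = resolve_full_shape(num_elements, (-1, 28, 28))

# Without it, (28, 28) has room for only one sample.
try:
    resolve_full_shape(num_elements, (28, 28))
    fixed_shape_fails = False
except ValueError:
    fixed_shape_fails = True
```

Here `with_batch_axis` comes out as (32, 28, 28), and the fixed (28, 28) target is rejected for any batch larger than one.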