How to zero-pad the output of a Dense layer on the right in Keras?

Time: 2018-11-19 01:01:34

Tags: python tensorflow keras deep-learning

I can see there is ZeroPadding1D in the Keras docs, but it expects a 3D tensor with shape (batch, axis_to_pad, features). How can I pad the output of a Dense layer, which has shape (batch, features), with zeros on the right?

Something like this:

x = Dense(64, activation='linear')(x)
x = ZeroPadding1D(padding=(0,64))(x)

Update:

I tried to implement a custom layer:

from keras.layers import Layer, concatenate
from keras import backend as K

class CustomZeroPadding1D(Layer):
    def __init__(self, **kwargs):
        super(CustomZeroPadding1D, self).__init__(**kwargs)

    def build(self, input_shape):
        super(CustomZeroPadding1D, self).build(input_shape)

    def call(self, x):
        # Append a block of zeros with the same shape as the input along the feature axis
        res = concatenate([x, K.zeros_like(x)], axis=-1)
        return res

    def compute_output_shape(self, input_shape):
        print('-'*60)
        print('input_shape', input_shape)
        # Batch dimension unchanged, feature dimension doubled
        output_shape = (input_shape[0], input_shape[1]*2)
        print('output_shape', output_shape)
        print('-' * 60)
        return output_shape

The shapes in model.summary() look fine, but when I train, it fails with Incompatible shapes: [32,312] vs. [0,312], where batch_size = 32 and n_features = 312. So for some reason it seems the custom layer sets the batch size to 0?
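
For reference, the same doubling-with-zeros idea written against the backend K.concatenate op (rather than the keras.layers concatenate helper) would look like the sketch below; the class name is illustrative, and whether this avoids the shape error is only an assumption:

# Sketch only: same behaviour as the layer above, but using the backend op
# K.concatenate inside call(). Not verified against the error described above.
from keras.layers import Layer
from keras import backend as K

class BackendZeroPadding1D(Layer):
    def call(self, x):
        # Append a zero block of the same shape along the feature axis
        return K.concatenate([x, K.zeros_like(x)], axis=-1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1] * 2)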

1 Answer:

Answer 0 (score: 0)

You can do this with the Keras Reshape() layer together with ZeroPadding1D:

x = Dense(64, activation='linear')(x)
x = Reshape([-1, 1])(x)
x = ZeroPadding1D(padding=(0,64))(x)

A simple example:

from keras.layers import Dense, ZeroPadding1D, Reshape
import tensorflow as tf

# Dummy input: 3 samples with 4 features each
x = tf.constant([ [1.,1.,1.,1.], [1.,1.,1.,1.], [1.,1.,1.,1.] ])

x = Dense(64, activation='linear')(x)
print(x.shape)
x = Reshape([-1, 1])(x)                # (batch, 64) -> (batch, 64, 1) so ZeroPadding1D applies
print(x.shape)
x = ZeroPadding1D(padding=(0,10))(x)   # append 10 zeros on the right of the padded axis
print(x.shape)
x_ = x
x = Reshape([-1])(x)                   # back to 2D: (batch, 74)
print(x.shape)

# TF1-style session to evaluate the tensors
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(x_.eval()[0])
    print(x.eval()[0])

Output:

(3, 64)
(3, 64, 1)
(3, 74, 1)
(3, 74)
[[-0.4807737 ]
 [-0.11104995]
 [-0.13179883]
 [-0.34604812]
 [-0.51428366]
 [ 0.39941764]
 [-0.0830393 ]
 [-0.79970515]
 [-0.34765747]
 [ 0.6095309 ]
 [ 0.03473596]
 [-0.68571013]
 [-0.12576953]
 [ 0.08276424]
 [ 0.06647275]
 [ 0.1670239 ]
 [-0.26894042]
 [ 0.03662822]
 [ 0.24533364]
 [ 0.2816307 ]
 [-0.28530025]
 [ 0.33335078]
 [-0.52831376]
 [-0.5450369 ]
 [-0.30863497]
 [ 0.14870909]
 [ 0.4303183 ]
 [-0.11658342]
 [ 0.60449684]
 [-0.47217163]
 [-0.3101738 ]
 [-0.15606529]
 [ 0.25018048]
 [-0.34411854]
 [-0.6233641 ]
 [ 0.01476687]
 [-0.32950908]
 [ 0.40554196]
 [ 0.2916515 ]
 [-0.5265654 ]
 [-0.13000801]
 [ 0.45457274]
 [-0.32708472]
 [ 0.20291099]
 [ 0.15016158]
 [ 0.02729714]
 [ 0.33809263]
 [-0.67841053]
 [ 0.31094086]
 [-0.3722076 ]
 [ 0.31136334]
 [ 0.21413101]
 [-0.40144968]
 [ 0.37131637]
 [ 0.17351764]
 [ 0.1576828 ]
 [-0.299753  ]
 [ 0.32608157]
 [-0.15042162]
 [-0.2388339 ]
 [ 0.18553   ]
 [-0.4828058 ]
 [ 0.07377535]
 [-0.291501  ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]]

[-0.4807737  -0.11104995 -0.13179883 -0.34604812 -0.51428366  0.39941764
 -0.0830393  -0.79970515 -0.34765747  0.6095309   0.03473596 -0.68571013
 -0.12576953  0.08276424  0.06647275  0.1670239  -0.26894042  0.03662822
  0.24533364  0.2816307  -0.28530025  0.33335078 -0.52831376 -0.5450369
 -0.30863497  0.14870909  0.4303183  -0.11658342  0.60449684 -0.47217163
 -0.3101738  -0.15606529  0.25018048 -0.34411854 -0.6233641   0.01476687
 -0.32950908  0.40554196  0.2916515  -0.5265654  -0.13000801  0.45457274
 -0.32708472  0.20291099  0.15016158  0.02729714  0.33809263 -0.67841053
  0.31094086 -0.3722076   0.31136334  0.21413101 -0.40144968  0.37131637
  0.17351764  0.1576828  -0.299753    0.32608157 -0.15042162 -0.2388339
  0.18553    -0.4828058   0.07377535 -0.291501    0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.        ]
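
A variant that skips the Reshape round-trip is to pad the 2D tensor directly with a Lambda layer. This is only a sketch, assuming the goal is to double the feature axis with zeros as in the question (the input width 312 and Dense width 64 are taken from the question; the model itself is illustrative):

from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K

inp = Input(shape=(312,))
x = Dense(64, activation='linear')(inp)
# Append a zero block the same width as the features: (batch, 64) -> (batch, 128)
x = Lambda(lambda t: K.concatenate([t, K.zeros_like(t)], axis=-1),
           output_shape=lambda s: (s[0], s[1] * 2))(x)
model = Model(inp, x)
model.summary()  # final output shape should be (None, 128)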