我可以看到Keras [doc]中有ZeroPadding1D
,但它需要3D tensor with shape (batch, axis_to_pad, features)
,但是如何将形状为(batch, features)
的右侧的密集层输出置零?
类似的东西:
x = Dense(64, activation='linear')(x)
x = ZeroPadding1D(padding=(0,64))(x)
更新:
我尝试实现自定义层:
class CustomZeroPadding1D(Layer):
def __init__(self, **kwargs):
super(CustomZeroPadding1D, self).__init__(**kwargs)
def build(self, input_shape):
super(CustomZeroPadding1D, self).build(input_shape)
def call(self, x):
res = concatenate([x, K.zeros_like(x)], axis=-1)
return res
def compute_output_shape(self, input_shape):
print('-'*60)
print('input_shape', input_shape)
output_shape = (input_shape[0], input_shape[1]*2)
print('output_shape', output_shape)
print('-' * 60)
return output_shape
model.summary()
中的形状看起来不错,但是当我进行训练时,它以Incompatible shapes: [32,312] vs. [0,312]
失败,其中batch_size = 32和n_features = 312,因此由于某种原因,似乎自定义层设置了批次大小= 0? / p>
答案 0 :(得分:0)
您可以使用tensorflow.reshape()来做到这一点:
x = Dense(64, activation='linear')(x)
x = Reshape([-1, 1])(x)
x = ZeroPadding1D(padding=(0,64))(x)
一个简单的示例:
from keras.layers import Dense, ZeroPadding1D, Reshape
import tensorflow as tf
x = tf.constant([ [1.,1.,1.,1.], [1.,1.,1.,1.], [1.,1.,1.,1.] ])
x = Dense(64, activation='linear')(x)
print(x.shape)
x = Reshape([-1, 1])(x)
print(x.shape)
x = ZeroPadding1D(padding=(0,10))(x)
print(x.shape)
x_ = x
x = Reshape([-1])(x)
print(x.shape)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(x_.eval()[0])
print(x.eval()[0])
退出:
(3, 64)
(3, 64, 1)
(3, 74, 1)
(3, 74)
[[-0.4807737 ]
[-0.11104995]
[-0.13179883]
[-0.34604812]
[-0.51428366]
[ 0.39941764]
[-0.0830393 ]
[-0.79970515]
[-0.34765747]
[ 0.6095309 ]
[ 0.03473596]
[-0.68571013]
[-0.12576953]
[ 0.08276424]
[ 0.06647275]
[ 0.1670239 ]
[-0.26894042]
[ 0.03662822]
[ 0.24533364]
[ 0.2816307 ]
[-0.28530025]
[ 0.33335078]
[-0.52831376]
[-0.5450369 ]
[-0.30863497]
[ 0.14870909]
[ 0.4303183 ]
[-0.11658342]
[ 0.60449684]
[-0.47217163]
[-0.3101738 ]
[-0.15606529]
[ 0.25018048]
[-0.34411854]
[-0.6233641 ]
[ 0.01476687]
[-0.32950908]
[ 0.40554196]
[ 0.2916515 ]
[-0.5265654 ]
[-0.13000801]
[ 0.45457274]
[-0.32708472]
[ 0.20291099]
[ 0.15016158]
[ 0.02729714]
[ 0.33809263]
[-0.67841053]
[ 0.31094086]
[-0.3722076 ]
[ 0.31136334]
[ 0.21413101]
[-0.40144968]
[ 0.37131637]
[ 0.17351764]
[ 0.1576828 ]
[-0.299753 ]
[ 0.32608157]
[-0.15042162]
[-0.2388339 ]
[ 0.18553 ]
[-0.4828058 ]
[ 0.07377535]
[-0.291501 ]
[ 0. ]
[ 0. ]
[ 0. ]
[ 0. ]
[ 0. ]
[ 0. ]
[ 0. ]
[ 0. ]
[ 0. ]
[ 0. ]]
[-0.4807737 -0.11104995 -0.13179883 -0.34604812 -0.51428366 0.39941764
-0.0830393 -0.79970515 -0.34765747 0.6095309 0.03473596 -0.68571013
-0.12576953 0.08276424 0.06647275 0.1670239 -0.26894042 0.03662822
0.24533364 0.2816307 -0.28530025 0.33335078 -0.52831376 -0.5450369
-0.30863497 0.14870909 0.4303183 -0.11658342 0.60449684 -0.47217163
-0.3101738 -0.15606529 0.25018048 -0.34411854 -0.6233641 0.01476687
-0.32950908 0.40554196 0.2916515 -0.5265654 -0.13000801 0.45457274
-0.32708472 0.20291099 0.15016158 0.02729714 0.33809263 -0.67841053
0.31094086 -0.3722076 0.31136334 0.21413101 -0.40144968 0.37131637
0.17351764 0.1576828 -0.299753 0.32608157 -0.15042162 -0.2388339
0.18553 -0.4828058 0.07377535 -0.291501 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. ]