我正在使用tensorflow和keras来构建神经网络。
我想对keras.layers.Conv2DTranspose
(Definition in Keras documentation)使用转置卷积
我使用了tutorial,并按照如下方式定义了我的网络:
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Conv2DTranspose
sess = tf.Session()
batch_size = 20
model = Sequential()
model.add(Conv2DTranspose(filters = (batch_size,1,2700, 1),kernel_size = (2700,1), activation = 'relu', input_shape = (1,1,1,1)))
出现以下错误:
ValueError Traceback (most recent call last)
<ipython-input-3-01a3b17fa36f> in <module>()
11 model = Sequential()
12
---> 13 model.add(Conv2DTranspose(filters = (batch_size,1,2700, 1),kernel_size = (2700,1), activation = 'relu', input_shape = (1,1,1,1)))
/usr/local/lib/python3.5/dist-packages/keras/models.py in add(self, layer)
465 # and create the node connecting the current layer
466 # to the input layer we just created.
--> 467 layer(x)
468
469 if len(layer._inbound_nodes[-1].output_tensors) != 1:
/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py in __call__(self, inputs, **kwargs)
573 # Raise exceptions in case the input is not compatible
574 # with the input_spec specified in the layer constructor.
--> 575 self.assert_input_compatibility(inputs)
576
577 # Collect input shapes to build layer.
/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py in assert_input_compatibility(self, inputs)
472 self.name + ': expected ndim=' +
473 str(spec.ndim) + ', found ndim=' +
--> 474 str(K.ndim(x)))
475 if spec.max_ndim is not None:
476 ndim = K.ndim(x)
ValueError: Input 0 is incompatible with layer conv2d_transpose_1: expected ndim=4, found ndim=5
尽管如此,我的输入的尺寸为4(input_shape = (1,1,1,1)
)
如何正确定义输入,然后添加一些图层?
答案 0 :(得分:1)
在Conv2DTranspose的documentation中,您可以看到filters
是位置参数,应该是整数。此整数指定您要在图层中使用的过滤器数量。
下一个参数是kernel_size
(也是位置参数),它指定过滤器的形状。
我认为您正在寻找的是:
model.add(Conv2DTranspose(n_filt, (2700, 1), activation='relu' input_shape = (1, 1, 1,)))
其中n_filt
是图层中转置卷积滤波器的数量。
注意:
不要在input_shape
参数中提供批处理尺寸,或改用batch_input_shape
。
编辑以澄清:如果数据批的维度为(batch_size, dim1, dim2, dim3)
,则应传递input_shape = (dim1, dim2, dim3, )
(请注意最后一个空格之前的逗号)或batch_input_shape = (batch_size, dim1, dim2, dim3)
。我建议使用input_shape
,因为您对要使用的批处理大小没有任何限制。
不能使用关键字指定位置参数,即,如果filters = ...
是位置参数,则不要在函数调用中使用filters
。