Keras custom layer (whose superclass is another custom layer) does not accept multiple inputs

Time: 2019-07-31 19:08:51

Tags: python keras deep-learning keras-layer

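I am trying to implement a Keras custom layer whose superclass is another custom layer. The superclass takes a single input and works fine, but the subclass, which should accept two inputs, does not accept multiple inputs.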

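I tried adding Layer as a second base class alongside the superclass, but the result is the same. I think the problem is somewhere in the build step.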

class SpadeLayer(ConvSN2D, Layer):

    def __init__(self, filters, kernel_size, **kwargs):
        super(SpadeLayer, self).__init__(filters, kernel_size, **kwargs)

    def build(self, input_shape):
        self.bn = BatchNormalization(center=False, scale=False)
        self.conv0 = ConvSN2D(128, kernel_size , strides=1, padding='same',kernel_initializer='glorot_uniform')
        self.conv1 = ConvSN2D(filters, kernel_size , strides=1, padding='same',kernel_initializer='glorot_uniform')
        self.conv2 = ConvSN2D(filters, kernel_size , strides=1, padding='same',kernel_initializer='glorot_uniform')
        super(SpadeLayer, self).build(input_shape)

    def call(self, inputs):
        _, f_h, f_w, _ = inputs[0].get_shape().as_list()
        segmap_down = K.tf.image.resize_images(inputs[1], (f_h, f_w))

        init_conv = self.conv0()(segmap_down)
        gamma = self.conv1()(init_conv)
        beta = self.conv1()(init_conv)
        return (self.bn(features) * gamma) + beta

layer = SpadeLayer(3, 3)
input1 = Input(shape=(16,16,32))
input2 = Input(shape=(16,16,1))
output = layer([input1, input2])
model = tf.keras.Model(inputs=[input], outputs=[output])
model.summary()

I get the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-39-45f88bca3df3> in <module>
      2 input1 = Input(shape=(16,16,32))
      3 input2 = Input(shape=(16,16,1))
----> 4 output = layer([input1, input2])
      5 model = tf.keras.Model(inputs=[input], outputs=[output])
      6 model.summary()

C:\Anaconda3\envs\env_tf\lib\site-packages\keras\engine\base_layer.py in __call__(self, inputs, **kwargs)
    412                 # Raise exceptions in case the input is not compatible
    413                 # with the input_spec specified in the layer constructor.
--> 414                 self.assert_input_compatibility(inputs)
    415 
    416                 # Collect input shapes to build layer.

C:\Anaconda3\envs\env_tf\lib\site-packages\keras\engine\base_layer.py in assert_input_compatibility(self, inputs)
    297                              'but it received ' + str(len(inputs)) +
    298                              ' input tensors. Input received: ' +
--> 299                              str(inputs))
    300         for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    301             if spec is None:

ValueError: Layer spade_layer_11 expects 1 inputs, but it received 2 input tensors. Input received: [<tf.Tensor 'input_21:0' shape=(?, 16, 16, 32) dtype=float32>, <tf.Tensor 'input_22:0' shape=(?, 16, 16, 1) dtype=float32>]
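
A likely cause, assuming ConvSN2D derives from Keras's Conv2D (its definition is not shown here): convolution layers set an input_spec describing a single input tensor in their constructor, and the compatibility check in __call__ compares the number of inputs against that spec before build or call ever run. The sketch below reproduces the same check with the stock Conv2D under TensorFlow 2.x; the tensors `features` and `segmap` are placeholders chosen for illustration.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Stock Conv2D stands in for ConvSN2D here (assumption: ConvSN2D subclasses it).
conv = layers.Conv2D(3, 3)

# The constructor already attaches a spec for exactly one input tensor.
print(conv.input_spec)

# Placeholder tensors with the shapes used in the question.
features = tf.zeros((1, 16, 16, 32))
segmap = tf.zeros((1, 16, 16, 1))

try:
    # Two tensors against a one-input spec: rejected before call() runs.
    conv([features, segmap])
    raised = False
except ValueError as err:
    raised = True
    print(err)
```

Because SpadeLayer inherits that constructor, overriding build or call alone would not be expected to change the outcome.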

I just don't know what to do to fix it. I have checked a lot of resources but did not find anything similar.
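
One possible way around this, as a sketch rather than a definitive fix: derive from Layer alone and hold the convolutions as sub-layers, so that no single-input input_spec is inherited. The example assumes TensorFlow 2.x and substitutes the stock Conv2D for ConvSN2D, whose definition is not shown; the class name SpadeSketch and the `hidden` argument are hypothetical. It also uses a separate convolution for beta, unpacks `features` from the input list, and sets filters to 32 so that gamma and beta match the 32 feature channels, since the elementwise product would otherwise fail to broadcast.

```python
import tensorflow as tf
from tensorflow.keras import layers


class SpadeSketch(layers.Layer):
    """SPADE-style normalization that takes [features, segmap] as input."""

    def __init__(self, filters, kernel_size, hidden=128, **kwargs):
        super().__init__(**kwargs)
        # Sub-layers are composed, not inherited. Conv2D is a stand-in
        # for ConvSN2D (assumption: it accepts the same arguments).
        self.bn = layers.BatchNormalization(center=False, scale=False)
        self.conv0 = layers.Conv2D(hidden, kernel_size, padding='same')
        self.conv_gamma = layers.Conv2D(filters, kernel_size, padding='same')
        self.conv_beta = layers.Conv2D(filters, kernel_size, padding='same')

    def call(self, inputs, training=None):
        features, segmap = inputs
        # Assumes the spatial size of the features is statically known.
        height, width = features.shape[1], features.shape[2]
        segmap_down = tf.image.resize(segmap, (height, width))

        hidden = self.conv0(segmap_down)
        gamma = self.conv_gamma(hidden)
        beta = self.conv_beta(hidden)
        # filters must equal the channel count of the features here.
        return self.bn(features, training=training) * gamma + beta


input1 = tf.keras.Input(shape=(16, 16, 32))
input2 = tf.keras.Input(shape=(16, 16, 1))
output = SpadeSketch(32, 3)([input1, input2])
model = tf.keras.Model(inputs=[input1, input2], outputs=output)
model.summary()
```

Note that the model is built from input1 and input2 here, whereas the original snippet passes the Python builtin input. If spectral normalization is required, ConvSN2D could presumably be swapped back in for Conv2D inside the constructor.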

0 answers:

No answers