TypeError:add_weight()至少需要3个参数(给定4个参数)

时间:2017-05-23 00:34:15

标签: python tensorflow keras

class MyLayer(Layer):
    """Split the feature axis in two halves and transform one half.

    For an input of shape (batch, input_dim) the output has shape
    (batch, 2, input_dim // 2): ``output[:, 0]`` is the second half of the
    features multiplied by a learned (input_dim // 2, input_dim // 2)
    kernel, and ``output[:, 1]`` is the untouched first half.

    NOTE(review): assumes a 2-D input (batch, input_dim) with an even
    ``input_dim`` -- TODO confirm against the caller.
    """

    def __init__(self, output_dim, **kwargs):
        # Kept for inspection / backward compatibility; the real output
        # shape is derived from the input shape in compute_output_shape().
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the square trainable kernel applied to one feature half."""
        input_dim = input_shape[-1]
        # Floor division keeps the shape integral on Python 3 too; on
        # Python 2, ``/`` between ints already floors, so nothing changes there.
        half_dim = input_dim // 2
        # ``name`` has no default in Layer.add_weight, so it must be given;
        # leaving it out raised "add_weight() takes at least 3 arguments
        # (4 given)".
        self.kernel = self.add_weight(name='kernel',
                                      shape=(half_dim, half_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, x):
        """Return a (batch, 2, input_dim // 2) tensor of [transformed, kept]."""
        # Use the last axis, consistent with build().
        half_dim = K.int_shape(x)[-1] // 2
        # (batch, input_dim) -> (batch, 2, half_dim): [:, 0] holds the first
        # half of the features, [:, 1] the second half.
        halves = tf.reshape(x, [-1, 2, half_dim])
        transformed = tf.matmul(halves[:, 1, :], self.kernel)
        kept = halves[:, 0, :]
        # Stack per sample on a new axis 1.  The former concat(axis=0) +
        # reshape paired rows from *different* samples whenever batch > 1.
        return tf.stack([transformed, kept], axis=1)

    def compute_output_shape(self, input_shape):
        # Must mirror call().  The former (batch, output_dim) produced a
        # nested tuple such as (None, (2, 256)) when output_dim was a tuple.
        return (input_shape[0], 2, input_shape[-1] // 2)

请帮助我,我错过了什么?

Stacktrace如下:

Traceback (most recent call last):
File "/root/PycharmProjects/tranferNET/modelBuild.py", line 193, in <module> model = create_network([100, 100, 3])
File "/root/PycharmProjects/tranferNET/modelBuild.py", line 174, in create_network com_distribution = MyLayer((2,256))(merge_common2)
File "/root/Tensorflow/local/lib/python2.7/site-packages/keras/engine/topology.py", line 558, in __call__ self.build(input_shapes[0])
File "/root/PycharmProjects/tranferNET/ncLayer.py", line 70, in build trainable=True)
File "/root/Tensorflow/local/lib/python2.7/site-packages/keras/legacy/interfaces.py", line 88, in wrapper 
return func(*args, **kwargs)
TypeError: add_weight() takes at least 3 arguments (4 given)

1 个答案:

答案 0 :(得分:0)

这是add_weight()函数的定义:

 def add_weight(self,
                name,
                shape,
                dtype=None,
                initializer=None,
                regularizer=None,
                trainable=True,
                constraint=None):
     """Adds a weight variable to the layer.

     Note: `name` and `shape` have no default values, so every call must
     supply both (positionally after `self`, or by keyword). Omitting
     `name` is what produces "add_weight() takes at least 3 arguments
     (4 given)" on Python 2.

     # Arguments
         name: String, the name for the weight variable.
         shape: The shape tuple of the weight.
         dtype: The dtype of the weight.
         initializer: An Initializer instance (callable).
         regularizer: An optional Regularizer instance.
         trainable: A boolean, whether the weight should
             be trained via backprop or not (assuming
             that the layer itself is also trainable).
         constraint: An optional Constraint instance.
     # Returns
         The created weight variable.
     """
     initializer = initializers.get(initializer)
     if dtype is None:
         dtype = K.floatx()
     weight = K.variable(initializer(shape), dtype=dtype, name=name)
     if regularizer is not None:
         self.add_loss(regularizer(weight))
     if constraint is not None:
         self.constraints[weight] = constraint
     # Route the new weight to the trainable or non-trainable list.
     if trainable:
         self._trainable_weights.append(weight)
     else:
         self._non_trainable_weights.append(weight)
     return weight

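在以关键字参数(kwargs)传入 `initializer` 和 `trainable` 之前,您需要在 `self` 之后先提供 `name` 和 `shape` 这两个参数——它们在签名中没有默认值,因此是必需的。原代码漏掉了 `name`;Python 2 在报错时会把 `self` 和 3 个关键字参数一起计数,所以才出现"至少需要 3 个参数(给定 4 个参数)"这种看似矛盾的信息。形状可以作为 `name` 之后的第二个位置参数直接传入 `kernel_shape`,而不是写成 `shape=kernel_shape`,例如:`self.add_weight('kernel', kernel_shape, initializer='uniform', trainable=True)`。全部使用关键字参数(`self.add_weight(name='kernel', shape=kernel_shape, initializer='uniform', trainable=True)`)同样有效,关键在于必须提供 `name`。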