Trainable matrix multiplication layer

Date: 2020-08-20 12:47:18

Tags: python tensorflow keras keras-layer tf.keras

I am trying to build a (custom) trainable matrix-multiplication layer in TensorFlow, but it isn't working... More precisely, my model should look like this:

x -> A(x) x

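where A(x) is a feed-forward network whose values are n x n matrices (so the matrix depends on the input x), and A(x) x is the matrix-vector product of A(x) with x.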
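Read against the layer code that follows, the model can be sketched as below. Here $W_1$, $W_2$ and $b$ stand for `Tw1`, `Tw2` and `Tb`, $d$ is the input dimension, and $\operatorname{mat}_n$ is a hypothetical name for reshaping an $n^2$-vector into an $n \times n$ matrix. Because $A(x)$ acts on $x$, the product only makes sense when $d = n$, and for a batch of $B$ inputs the network yields $B$ separate matrices, one per sample.

```latex
h(x) = \operatorname{ReLU}\!\left(W_1^{\top} x + b\right) \in \mathbb{R}^{d}, \qquad
A(x) = \operatorname{mat}_n\!\left(W_2^{\top} h(x)\right) \in \mathbb{R}^{n \times n}, \qquad
y = A(x)\,x \in \mathbb{R}^{n},
\qquad W_1 \in \mathbb{R}^{d \times d},\; W_2 \in \mathbb{R}^{d \times n^2},\; b \in \mathbb{R}^{d}.
```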

Here is the code I wrote:

import tensorflow as tf

class custom_layer(tf.keras.layers.Layer):
    
    def __init__(self, units=16, input_dim=32):
        super(custom_layer, self).__init__()
        self.units = units
    
    def build(self, input_shape):
        self.Tw1 = self.add_weight(name='Weights_1 ',
                                    shape=(input_shape[-1], input_shape[-1]),
                                    initializer='GlorotUniform',
                                    trainable=True)
        
        self.Tw2 = self.add_weight(name='Weights_2 ',
                                    shape=(input_shape[-1], (self.units)**2),
                                    initializer='GlorotUniform',
                                    trainable=True)
        
        self.Tb = self.add_weight(name='biases',
                                    shape=(input_shape[-1],),
                                    initializer='GlorotUniform',#Previously 'ones'
                                    trainable=True)

        
    def call(self, input):
        # Build Vector-Valued Feed-Forward Network
        ffNN = tf.matmul(input, self.Tw1) + self.Tb
        ffNN = tf.nn.relu(ffNN)
        ffNN = tf.matmul(ffNN, self.Tw2) 
    
        
        # Map to Matrix
        ffNN = tf.reshape(ffNN, [self.units,self.units])

        # Multiply Matrix-Valued function with input data
        x_out = tf.matmul(ffNN,input)
        
        # Return Output
        return x_out
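To see where the 128 in the error message could come from, here is a minimal sketch that traces tensor shapes through the first three lines of `call`. It assumes Keras' default `batch_size` of 32 in `fit()`, an input dimension of 2 and `units=2`; the random tensors are hypothetical stand-ins for `Tw1`, `Tw2` and `Tb` with the shapes `build` creates.

```python
import tensorflow as tf

# Assumed sizes: default fit() batch of 32, input dim 2, units=2
batch, d, units = 32, 2, 2
x = tf.random.normal([batch, d])

# Hypothetical stand-ins with the same shapes as Tw1, Tw2 and Tb in build()
Tw1 = tf.random.normal([d, d])
Tw2 = tf.random.normal([d, units ** 2])
Tb = tf.zeros([d])

h = tf.nn.relu(tf.matmul(x, Tw1) + Tb)  # (batch, d)
ffNN = tf.matmul(h, Tw2)                # (batch, units**2): one flattened matrix per sample

print(h.shape, ffNN.shape)  # (32, 2) (32, 4)
print(int(tf.size(ffNN)))   # 128 values, but [units, units] only has room for 4
```

The feed-forward part is applied row by row, so it keeps the batch axis; only the later reshape assumes there is a single sample.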

Now I build the model:

input_layer = tf.keras.Input(shape=[2])
output_layer  = custom_layer(2)(input_layer)
model = tf.keras.Model(inputs=[input_layer], outputs=[output_layer])

# Compile Model
#----------------#
# Define Optimizer
optimizer_on = tf.keras.optimizers.SGD(learning_rate=10**(-1))
# Compile
model.compile(loss = 'mse',
                optimizer = optimizer_on,
                metrics = ['mse'])

# Fit Model
#----------------#
model.fit(data_x, data_y, epochs=(10**1), verbose=0)

Then I get this error message:

InvalidArgumentError:  Input to reshape is a tensor with 128 values, but the requested shape has 4
     [[node model_62/reconfiguration_unit_70/Reshape (defined at <ipython-input-176-0b494fa3fc75>:46) ]] [Op:__inference_distributed_function_175181]

Errors may have originated from an input operation.
Input Source operations connected to node model_62/reconfiguration_unit_70/Reshape:
 model_62/reconfiguration_unit_70/MatMul_1 (defined at <ipython-input-176-0b494fa3fc75>:41)

Function call stack:
distributed_function
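`tf.reshape` only succeeds when the target shape holds exactly as many elements as the input. A minimal sketch, assuming a batch of 32 and `units=2` (the zero tensor is a hypothetical stand-in for the network output), of the failing call and of a variant that keeps the batch axis by passing `-1`:

```python
import tensorflow as tf

units = 2
# Stand-in for the (batch, units**2) tensor built in call(): 32 * 4 = 128 values
ffNN = tf.zeros([32, units * units])

# Target [units, units] holds only 4 values, so this raises InvalidArgumentError
reshape_failed = False
try:
    tf.reshape(ffNN, [units, units])
except tf.errors.InvalidArgumentError as err:
    reshape_failed = True
    print("reshape failed:", err.message)

# -1 lets TensorFlow infer the batch axis: one units x units matrix per sample
A = tf.reshape(ffNN, [-1, units, units])
print(A.shape)  # (32, 2, 2)
```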

Thoughts: the problem seems to be with the network's dimensions, but I can't figure out how to fix it...
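One possible fix, sketched under assumptions rather than offered as a definitive answer: keep the batch axis in the reshape so each sample gets its own `units x units` matrix, then multiply each matrix by its own input row with `tf.linalg.matvec`. A plain `tf.matmul(ffNN, input)` would not line up the shapes `(batch, n, n)` and `(batch, n)`. The class name `MatrixValuedLayer`, the `zeros` bias initializer and the random `data_x`/`data_y` are hypothetical; the layer also assumes `units` equals the input dimension, since A(x) must be square to act on x.

```python
import numpy as np
import tensorflow as tf


class MatrixValuedLayer(tf.keras.layers.Layer):
    """Hypothetical sketch of x -> A(x) x, with A(x) produced by a small FFNN.

    Assumes units == input dimension so that A(x) (units x units) can act on x.
    """

    def __init__(self, units=2, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        d = int(input_shape[-1])
        if d != self.units:
            raise ValueError("A(x) x requires units == input dimension")
        self.Tw1 = self.add_weight(name="weights_1", shape=(d, d),
                                   initializer="glorot_uniform", trainable=True)
        self.Tw2 = self.add_weight(name="weights_2", shape=(d, self.units ** 2),
                                   initializer="glorot_uniform", trainable=True)
        self.Tb = self.add_weight(name="biases", shape=(d,),
                                  initializer="zeros", trainable=True)

    def call(self, inputs):
        # Feed-forward network, row-wise: (batch, d) -> (batch, units**2)
        h = tf.nn.relu(tf.matmul(inputs, self.Tw1) + self.Tb)
        flat = tf.matmul(h, self.Tw2)
        # Keep the batch axis: one units x units matrix per sample
        A = tf.reshape(flat, [-1, self.units, self.units])
        # Per-sample product A(x_i) x_i -> (batch, units)
        return tf.linalg.matvec(A, inputs)


# Hypothetical random data standing in for data_x / data_y
data_x = np.random.rand(100, 2).astype("float32")
data_y = np.random.rand(100, 2).astype("float32")

input_layer = tf.keras.Input(shape=[2])
output_layer = MatrixValuedLayer(2)(input_layer)
model = tf.keras.Model(inputs=[input_layer], outputs=[output_layer])
model.compile(loss="mse",
              optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
              metrics=["mse"])
model.fit(data_x, data_y, epochs=10, verbose=0)
```

`tf.linalg.matvec(A, x)` computes `A[i] @ x[i]` for every sample `i`; an equivalent formulation is `tf.matmul(A, x[..., None])` followed by `tf.squeeze(..., axis=-1)`.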

0 answers