在 Keras 自定义图层中保存图层输出

时间:2021-03-02 01:56:42

标签: python tensorflow keras shapes layer

我正在 Keras 中构建一个自定义层,该层应该记住其先前的输出。

层本身有几个嵌套层(conv2 -> norm -> relu),它们按顺序执行。 我想使用 tf.Variable 来存储我的结果,但由于批量大小可变而遇到问题。

我们的想法是:

  • 获取图层输入
  • 连接存储在变量中的先前输出(第一次运行时初始为零)
  • 通过嵌套层 conv2 -> norm -> relu 运行它
  • 将relu的输出保存到变量(tf.keras.backend.update)

如何避免批量大小可变的问题?它们阻止我在 build() 中初始化一个全为零的变量。

这是相关的代码片段:

class OnePassRecurrent(keras.layers.Layer):
    
    def __init__(self, output_size, kernel_size, **kwargs):
        self.output_size = output_size
        self.kernel_size = kernel_size
        super(OnePassRecurrent, self).__init__(**kwargs)
        
    def build(self, input_shape):
        self.conv1 = layers.Conv2D(self.output_size, self.kernel_size, padding = 'same', kernel_initializer = 'he_normal') # Notice no activation. Added after layer norm
        self.norm1 = layers.LayerNormalization()
        self.relu1 = layers.ReLU()
        
        output_shape = input_shape.as_list()
        output_shape[-1] = self.output_size

        # TODO This does not work right now
        #self.last_output = tf.Variable(tf.fill(input_shape, 0.0), validate_shape=False, trainable=False)
        #self.last_output = tf.Variable(tf.fill(tf.TensorShape(input_shape), 0.0))
                
    def call(self, inputs, **kwargs):
        lastx = self.last_output
        x = layers.Concatenate()([inputs, lastx])
            
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        
        tf.keras.backend.update(self.last_output, x)
        
        return x

感谢您的帮助!

0 个答案:

没有答案
相关问题