我正在 Keras 中构建一个自定义层,该层应该记住其先前的输出。
层本身有几个嵌套层(conv2 -> norm -> relu),它们按顺序执行。 我想使用 tf.Variable 来存储我的结果,但由于批量大小可变而遇到问题。
我们的想法是:
如何避免批量大小可变的问题?它们阻止我在 build() 中初始化一个全为零的变量。
这是相关的代码片段:
class OnePassRecurrent(keras.layers.Layer):
def __init__(self, output_size, kernel_size, **kwargs):
self.output_size = output_size
self.kernel_size = kernel_size
super(OnePassRecurrent, self).__init__(**kwargs)
def build(self, input_shape):
self.conv1 = layers.Conv2D(self.output_size, self.kernel_size, padding = 'same', kernel_initializer = 'he_normal') # Notice no activation. Added after layer norm
self.norm1 = layers.LayerNormalization()
self.relu1 = layers.ReLU()
output_shape = input_shape.as_list()
output_shape[-1] = self.output_size
# TODO This does not work right now
#self.last_output = tf.Variable(tf.fill(input_shape, 0.0), validate_shape=False, trainable=False)
#self.last_output = tf.Variable(tf.fill(tf.TensorShape(input_shape), 0.0))
def call(self, inputs, **kwargs):
lastx = self.last_output
x = layers.Concatenate()([inputs, lastx])
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
tf.keras.backend.update(self.last_output, x)
return x
感谢您的帮助!