如何在Keras中使用“状态”变量/张量创建自定义图层?

时间:2020-03-08 16:12:57

标签: tensorflow keras recurrent-neural-network layer stateful

我想问您一些有关创建自定义图层的帮助。 我实际上想做的很简单:生成一个带有“有状态”变量的输出层,即每个张量值都会更新的张量。

为了使所有内容更加清晰,下面是我想做的一小段:

def call(self, inputs)

   c = self.constant
   m = self.extra_constant

   update = inputs*m + c 
   X_new = self.X_old + update 

   outputs = X_new

   self.X_old = X_new   

   return outputs

这里的想法很简单:

  • X_olddef__ init__(self, ...)中被初始化为0
  • update是根据图层输入而计算的
  • 计算出该图层的输出(即X_new
  • X_old的值设置为等于X_new,以便在下一个批次中,X_old不再等于0,而是等于上一个X_new批次。

我发现K.update可以完成工作,如示例所示:

 X_new = K.update(self.X_old, self.X_old + update)

这里的问题是,如果我尝试将图层的输出定义为:

outputs = X_new

return outputs

尝试model.fit()时,我将收到以下错误:

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have 
gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

即使我强加layer.trainable = False并且没有为该层定义任何偏差或权重,我仍然会遇到此错误。另一方面,如果我只做self.X_old = X_new,则X_old的值不会得到更新。

你们有解决方案吗?我相信这不应该那么难,因为有状态的RNN也具有类似的功能。

在此先感谢您的帮助!

1 个答案:

答案 0 :(得分:2)

定义自定义图层有时会引起混乱。您重写的某些方法将被调用一次,但给您的印象是,就像许多其他OO库/框架一样,它们将被调用多次。

这是我的意思:当您定义一个图层并将其用于模型中时,您为覆盖call方法编写的python代码将不会在向前或向后传递中直接调用。相反,当您调用model.compile时,它仅被调用一次。它将python代码编译成一个计算图,而张量将在其中流动的图就是训练和预测期间的计算。

这就是为什么如果您想通过放置print语句来调试模型,那么它将不起作用;您需要使用tf.print向图形添加打印命令。

与您要拥有的状态变量的情况相同。除了简单地将old + update分配给new之外,您还需要调用Keras函数以将该操作添加到图形中。

请注意,张量是不可变的,因此您需要在tf.Variable方法中将状态定义为__init__

所以我相信这段代码更像您要寻找的东西:

class CustomLayer(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super(CustomLayer, self).__init__(**kwargs)
    self.state = tf.Variable(tf.zeros((3,3), 'float32'))
    self.constant = tf.constant([[1,1,1],[1,0,-1],[-1,0,1]], 'float32')
    self.extra_constant = tf.constant([[1,1,1],[1,0,-1],[-1,0,1]], 'float32')
    self.trainable = False

  def call(self, X):
    m = self.constant    
    c = self.extra_constant
    outputs = self.state + tf.matmul(X, m) + c
    tf.keras.backend.update(self.state, tf.reduce_sum(outputs, axis=0))

    return outputs