在Keras中定义自定义LSTM Cell?

时间:2019-01-17 08:03:54

标签: python tensorflow keras lstm

我将Keras与TensorFlow一起用作后端。如果我想对LSTM单元进行修改,例如“移除”输出门,该怎么办?这是一个乘法门,因此无论如何我都必须将其设置为固定值,以免乘以它。

1 个答案:

答案 0 :(得分:4)

首先,您应该定义own custom layer。如果需要一些直觉如何实现自己的单元,请参阅Keras存储库中的LSTMCell。例如。您的图层将是:

class MyLSTMCell(tf.keras.layers.Layer):

    def build():
       # define your own logic

    def call():
      # call your own logic

然后,使用tf.keras.layers.RNN来使用您的单元格:

x = tf.keras.layers.RNN(my_custom_cell)(inputs)