如何在keras中包装Tensorflow RNNCell?

时间:2018-12-12 21:26:24

标签: python tensorflow keras rnn keras-layer

我想像在keras层中一样实现自定义LSTM单元。实际上该实现存在于tensorflow中,所以我想知道是否可以将其包装为keras层并在模型中调用它。

我发现官方documentation过于简化,以至于看不到如何构建自定义RNN层。 herehere也有类似的问题,但它们似乎尚未解决。

预先感谢您的帮助!

2 个答案:

答案 0 :(得分:0)

根据我的理解,您应该只能够在类层的init()中初始化单元格,然后在调用方法内部使用您的输入对其进行引用。

例如:

class MySimpleLayer(Layer):
  def __init__(self, lstm_size):
    super(MySimpleLayer, self).__init__()
    self.lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)

  def call(self, batch, state):
    return self.lstm(batch, state)

layer = MySimpleLayer(lstm_size)
logits = layer(batch, state)

此实现是最基本的,因此对于更复杂的用例,您可能需要研究build()和compute_output_shape()方法。

答案 1 :(得分:0)

自问题发布以来,现在tensorflow的文档可能已有改善。

您可能需要检查this guidethis SO answer以供参考。