我想像在keras层中一样实现自定义LSTM单元。实际上该实现存在于tensorflow中,所以我想知道是否可以将其包装为keras层并在模型中调用它。
我发现官方documentation过于简化,以至于看不到如何构建自定义RNN层。 here和here也有类似的问题,但它们似乎尚未解决。
预先感谢您的帮助!
答案 0 :(得分:0)
根据我的理解,您应该只能够在类层的init()中初始化单元格,然后在调用方法内部使用您的输入对其进行引用。
例如:
class MySimpleLayer(Layer):
def __init__(self, lstm_size):
super(MySimpleLayer, self).__init__()
self.lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
def call(self, batch, state):
return self.lstm(batch, state)
layer = MySimpleLayer(lstm_size)
logits = layer(batch, state)
此实现是最基本的,因此对于更复杂的用例,您可能需要研究build()和compute_output_shape()方法。
答案 1 :(得分:0)
自问题发布以来,现在tensorflow的文档可能已有改善。
您可能需要检查this guide或this SO answer以供参考。