我正在写张量流中的RNN,我想将几个LSTM单元堆叠在一起。根据tensorflow教程,我应该这样做:
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(lstm_size)
stacked_lstm = tf.contrib.rnn.MultiRNNCell(
[lstm_cell() for _ in range(number_of_layers)])
当我这样做时,事情按预期工作。但我想知道我是否可以用简单的lambda函数替换函数的定义......遗憾的是,这不起作用。我用上面代码替换的是:
stacked_lstm = tf.contrib.rnn.MultiRNNCell([lambda:tf.contrib.rnn.
BasicLSTMCell(lstm_size) for _ in range(number_of_layers)])
我认为这样可行,因为我对python中的“lambda”的理解是特别的,所以我可以替换为这样非常简单的函数定义一个单独的函数。我对lambda的理解是错误的吗?我在做最后一个时得到的错误信息是:
AttributeError: 'function' object has no attribute 'zero_state'
我原以为堆叠LSTM的两种不同方法是等价的,但显然不是?
答案 0 :(得分:3)
[lambda:tf.contrib.rnn.BasicLSTMCell(lstm_size) for _ in range(number_of_layers)]
此处,lambda:tf.contrib.rnn.BasicLSTMCell(lstm_size)
仅定义一个函数,而不是将其称为。相反,您可以直接访问BasicLSTMCell()
功能:
[tf.contrib.rnn.BasicLSTMCell(lstm_size) for _ in range(number_of_layers)]
然后,您可以通过以下方式导入MultiRNNCell
和BasicLSTMCell
来进一步缩短它:
from tensorflow.contrib.rnn import MultiRNNCell, BasicLSTMCell
lstm_stacks = [BasicLSTMCell(lstm_size) for _ in range(number_of_layers)]
stacked_lstm = MultiRNNCell(lstm_stacks)