为什么我不能用lambda替换这个函数

时间:2017-12-04 22:25:29

标签: python tensorflow lambda

我正在写张量流中的RNN,我想将几​​个LSTM单元堆叠在一起。根据tensorflow教程,我应该这样做:

    def lstm_cell():
        return tf.contrib.rnn.BasicLSTMCell(lstm_size)
    stacked_lstm = tf.contrib.rnn.MultiRNNCell(
        [lstm_cell() for _ in range(number_of_layers)])

当我这样做时,事情按预期工作。但我想知道我是否可以用简单的lambda函数替换函数的定义......遗憾的是,这不起作用。我用上面代码替换的是:

    stacked_lstm = tf.contrib.rnn.MultiRNNCell([lambda:tf.contrib.rnn.
        BasicLSTMCell(lstm_size) for _ in range(number_of_layers)])

我认为这样可行,因为我对python中的“lambda”的理解是特别的,所以我可以替换为这样非常简单的函数定义一个单独的函数。我对lambda的理解是错误的吗?我在做最后一个时得到的错误信息是:

    AttributeError: 'function' object has no attribute 'zero_state'

我原以为堆叠LSTM的两种不同方法是等价的,但显然不是?

1 个答案:

答案 0 :(得分:3)

[lambda:tf.contrib.rnn.BasicLSTMCell(lstm_size) for _ in range(number_of_layers)]

此处,lambda:tf.contrib.rnn.BasicLSTMCell(lstm_size) 仅定义一个函数,而不是将其称为。相反,您可以直接访问BasicLSTMCell()功能:

[tf.contrib.rnn.BasicLSTMCell(lstm_size) for _ in range(number_of_layers)]

然后,您可以通过以下方式导入MultiRNNCellBasicLSTMCell来进一步缩短它:

from tensorflow.contrib.rnn import MultiRNNCell, BasicLSTMCell

lstm_stacks = [BasicLSTMCell(lstm_size) for _ in range(number_of_layers)]
stacked_lstm = MultiRNNCell(lstm_stacks)