使用带有sequence_length的tf.contrib.rnn.static_rnn时的初始化错误

时间:2017-11-06 06:16:57

标签: python tensorflow

以下代码可以正常使用:

with tf.variable_scope("lstm"):
    cell = tf.contrib.rnn.BasicLSTMCell(num_units=512)
    outputs, final_state = tf.contrib.rnn.static_rnn(
            cell=cell,
            inputs=x,
            dtype=tf.int32)

然而,当我给它sequence_length时,它会抱怨:

ValueError: Initializer for variable lstm/rnn/basic_lstm_cell/kernel/ is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.

我已在doc中读到"如果提供sequence_length向量,则执行动态计算。"我想知道为什么它与"循环或条件"有关。我应该如何解决这个问题。非常感谢!

编辑:这是破解的代码:

x = tf.placeholder(dtype=tf.float32, shape=[None, sentence_len, feature_len])
x_unstacked = tf.unstack(x, sentence_len, 1)
xlen = tf.placeholder(dtype=tf.int32, shape=[None])

with tf.variable_scope("lstm"):
    cell = tf.contrib.rnn.BasicLSTMCell(num_units=512)
    outputs, final_state = tf.contrib.rnn.static_rnn(
            cell=cell,
            inputs=x_unstacked,
            dtype=tf.int32,
            sequence_length=xlen)

0 个答案:

没有答案