如何在Tensorflow API r1.0中更改以下代码?

时间:2017-03-06 07:09:26

标签: tensorflow

我已将张量流从0.12升级到1.0。因此,我遇到了错误代码......我尝试解决错误,但我无法找到解决方案。 我希望您分享与此相关的知识和经验。谢谢。

self._initial_state = lstm_cell.zero_state(self.batch_size, tf.float32)
inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(axis=1, num_or_size_splits=4, value=pooled_concat)]

#previous code(v.0.11) : 
self._initial_state = lstm_cell.zero_state(self.batch_size, tf.float32)
inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(1, reduced, pooled_concat)] 


# -------- following code comes out the error--------- 
outputs, state = tf.nn.rnn(lstm_cell, inputs, initial_state=self._initial_state, sequence_length=self.real_len) 

# Above code creates a recurrent neural network specified by RNNCell cell in api 0.11. 
# tf.nn.rnn(cell, inputs, initial_state=None, dtype=None, sequence_length=None, scope=None) 

#---------------------------------------------------

2 个答案:

答案 0 :(得分:1)

自API tf.nn.rnn版本1.0起删除。尝试使用tf.nn.dynamic_rnn。但请注意,此方法需要inputs为张量而不是张量列表(请参阅链接文档),因此您必须按照创建inputs变量的方式进行更改。

答案 1 :(得分:1)

outputs, state = tf.nn.rnn(lstm_cell, inputs, initial_state=self._initial_state, sequence_length=self.real_len)

应该(现在,在TF 1.0中):

outputs, state = tf.contrib.rnn.static_rnn(lstm_cell, inputs, initial_state=self._initial_state, sequence_length=self.real_len)

由于nn.rnn已移至tf.contrib