我曾经使用以下方法在TensorFlow的0.8版本中创建RNN网络:
from tensorflow.python.ops import rnn
# Define a lstm cell with tensorflow
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
# Get lstm cell output
outputs, states = rnn.rnn(cell=lstm_cell, inputs=x, dtype=tf.float32)
rnn.rnn()
不再可用,并且听起来已被移至tf.contrib
。从BasicLSTMCell
创建RNN网络的确切代码是什么?
或者,如果我有堆叠的LSTM,
lstm_cell = tf.contrib.rnn.BasicLSTMCell(hidden_size, forget_bias=0.0)
stacked_lstm = tf.contrib.rnn.MultiRNNCell([lstm_cell] * num_layers)
outputs, new_state = tf.nn.rnn(stacked_lstm, inputs, initial_state=_initial_state)
新版TensorFlow中tf.nn.rnn
的替代品是什么?
答案 0 :(得分:13)
tf.nn.rnn
相当于tf.nn.static_rnn
。
注意:在version 1.2 of TensorFlow之前,名称空间 tf.nn.static_rnn
不存在,只有tf.contrib.rnn.static_rnn
(现在是别名 tf.nn.static_rnn
)。
答案 1 :(得分:2)