我最近将我的tesnorflow从Rev8升级到Rev12。在Rev8中,rnn_cell.LSTMCell中的默认“state_is_tuple”标志设置为False,因此我使用列表初始化了LSTM Cell,请参阅下面的代码。
#model definition
lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim)
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state)
#init_state place holder and feed_dict
def add_placeholders(self):
self.init_state = tf.placeholder("float", [None, self.cell_size])
def get_feed_dict(self, data, label):
feed_dict = {self.input_data: data,
self.input_label: reg_label,
self.init_state: np.zeros((self.config.batch_size, self.cell_size))}
return feed_dict
在Rev12中,默认的“state_is_tuple”标志设置为True,为了使我的旧代码工作,我必须明确地将标志变为False。但是,现在我收到了tensorflow的警告说:
“使用连接状态较慢,很快就会被弃用。 使用state_is_tuple = True“
我尝试通过将self.init_state的占位符定义更改为以下内容来使用元组初始化LSTM单元格:
self.init_state = tf.placeholder("float", (None, self.cell_size))
但现在我收到一条错误消息:
“'Tensor'对象不可迭代”
有谁知道如何使这项工作?
答案 0 :(得分:1)
现在使用cell.zero_state
向LSTM提供“零状态”要简单得多。您不需要明确地将初始状态定义为占位符。将其定义为张量,并在需要时提供它。这是它的工作原理,
lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim)
self.initial_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state)
如果你想提供一些其他值作为初始状态,比如说next_state = states[-1]
,在你的会话中计算并在feed_dict
中传递它 -
feed_dict[self.initial_state] = next_state
在您的问题中,lstm_cell.zero_state()
应该足够了。
无关,但请记住,您可以在Feed字典中传递张量和占位符!这就是self.initial_state
在上面示例中的工作方式。有关工作示例,请查看PTB Tutorial。