layer_1 = tf.layers.dense(inputs=layer_c, units=512, activation=tf.nn.tanh, name='layer1')
layer_2 = tf.layers.dense(inputs=layer_1, units=512, activation=tf.nn.tanh, name='layer2')
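Here, my layer_2 output has shape [batch_size, 512]. I need to send this layer_2 output through a single LSTM cell. However, when I try tf.nn.static_rnn, I get an error saying that my input should be a sequence. How can I do this?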
Answer 0 (score: 1)
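In the documentation for static_rnn, the inputs parameter expects a list:
inputs: A length T list of inputs, each a Tensor of shape [batch_size, input_size], or a nested tuple of such elements.
In your case T == 1, so you can pass it a single-element list containing your previous layer. To keep track of the internal cell state and hidden state so that they persist across time steps, you can add extra placeholders and pass them to static_rnn through the initial_state argument. Because cell.state_size is a tuple for an LSTM cell, (cell_state, hidden_state), you have to pass a tuple for this argument, and a tuple is returned for the output state.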
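Here is a minimal working example based on your code. It feeds a placeholder filled with ones as the input at each time step and tracks the internal state across time: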
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.layers, tf.Session)
import numpy as np

num_timesteps = 6
batch_size = 3
num_input_feats = 100
num_lstm_units = 5

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units)

input_x = tf.placeholder(tf.float32, [None, num_input_feats], name='input')
# placeholders that carry the LSTM state from one sess.run call to the next
input_c_state = tf.placeholder(tf.float32, [None, lstm_cell.state_size.c], name='c_state')
input_h_state = tf.placeholder(tf.float32, [None, lstm_cell.state_size.h], name='h_state')

layer_1 = tf.layers.dense(inputs=input_x, units=512, activation=tf.nn.tanh, name='layer1')
layer_2 = tf.layers.dense(inputs=layer_1, units=512, activation=tf.nn.tanh, name='layer2')

# static_rnn expects a list of T tensors; here T == 1, so wrap layer_2 in a list
layer_2_next, next_state = tf.nn.static_rnn(lstm_cell, [layer_2], dtype=tf.float32,
                                            initial_state=(input_c_state, input_h_state))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # initialize the internal cell state and hidden state to zero
    cur_c_state = np.zeros([batch_size, lstm_cell.state_size.c], dtype="float32")
    cur_h_state = np.zeros([batch_size, lstm_cell.state_size.h], dtype="float32")

    for i in range(num_timesteps):
        # here is your single timestep of input
        cur_x = np.ones([batch_size, num_input_feats], dtype="float32")

        y_out, out_state = sess.run([layer_2_next, next_state],
                                    feed_dict={input_x: cur_x,
                                               input_c_state: cur_c_state,
                                               input_h_state: cur_h_state})
        cur_c_state, cur_h_state = out_state  # pass states along to the next timestep
        print(y_out)  # here is your single timestep of output (a list with one array)
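To make the contract above more concrete, here is a rough NumPy-only sketch of the same pattern: a step function that takes a (c, h) state tuple, and a loop that consumes a Python list of T inputs and hands back the updated tuple. The names lstm_step and static_rnn_sketch, the weight layout, and the random initialization are all assumptions made for illustration. This is not TensorFlow's implementation and will not reproduce its numbers.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, state, W, b):
    """One simplified LSTM step (hypothetical helper).

    x: [batch, input_size]; state: (c, h), each [batch, units].
    """
    c, h = state
    z = np.concatenate([x, h], axis=1) @ W + b  # [batch, 4 * units]
    i, j, f, o = np.split(z, 4, axis=1)         # input, candidate, forget, output
    new_c = sigmoid(f) * c + sigmoid(i) * np.tanh(j)
    new_h = sigmoid(o) * np.tanh(new_c)
    return new_h, (new_c, new_h)

def static_rnn_sketch(step_fn, inputs, initial_state):
    """Unroll step_fn over a list of T inputs (hypothetical helper)."""
    if not isinstance(inputs, (list, tuple)):
        # mirrors the kind of "inputs must be a sequence" complaint in the question
        raise TypeError("inputs must be a sequence (list) of tensors")
    state = initial_state
    outputs = []
    for x in inputs:
        out, state = step_fn(x, state)
        outputs.append(out)
    return outputs, state

batch_size, input_size, units = 3, 512, 5
rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(input_size + units, 4 * units))
b = np.zeros(4 * units)

# zero initial state, then feed the returned state back in at every step
state = (np.zeros((batch_size, units)), np.zeros((batch_size, units)))
for _ in range(6):
    layer_2 = np.ones((batch_size, input_size))  # stand-in for the dense layer output
    outputs, state = static_rnn_sketch(
        lambda x, s: lstm_step(x, s, W, b), [layer_2], state)
    y_out = outputs[0]  # T == 1, so the output list has one element
```

Passing layer_2 directly instead of [layer_2] raises the TypeError in this sketch, which is the same kind of mistake that triggers the error described in the question.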