我想使用Tensorflow RNN(LSTM)从后续序列中预测标签。我想使训练(以及推理)的计算效率更高(也许是实时的)。
我已经实现了传递上一个批次中RNN的初始状态的功能。我想在Sample(一个序列)级别上执行此操作。我的代码:
batch_size = 2
tf.reset_default_graph()
input_ph = tf.placeholder(tf.float32, [None, None], name='inputs')
input_expanded = tf.expand_dims(input_ph, axis=2)
labels_ph = tf.placeholder(tf.float32, [None, 1], name='targets')
tf.random.set_random_seed(1234)
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=2)
initial_state = tf.Variable(rnn_cell.zero_state(batch_size, dtype=tf.float32), trainable=False)
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_expanded,
initial_state=initial_state,
dtype=tf.float32)
state_update = tf.assign(initial_state, state)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
inp = [[1,1,1],[1,1,1]]
#updating initial state (last state from previous batch)
out, sta, _ = sess.run([outputs, state, state_update], feed_dict={input_ph: inp})
是否可以为批处理中的每个样本初始化初始化状态?
例如。 (batchsize = 2):