我目前在tensorflow中有一系列链接在一起的RNN的代码。我没有使用MultiRNN,因为我稍后会对每一层的输出做一些事情。
for r in range(RNNS):
with tf.variable_scope('recurent_%d' % r) as scope:
state = [tf.zeros((BATCH_SIZE, sz)) for sz in rnn_func.state_size]
time_outputs = [None] * TIME_STEPS
for t in range(TIME_STEPS):
rnn_input = getTimeStep(rnn_outputs[r - 1], t)
time_outputs[t], state = rnn_func(rnn_input, state)
time_outputs[t] = tf.reshape(time_outputs[t], (-1, 1, RNN_SIZE))
scope.reuse_variables()
rnn_outputs[r] = tf.concat(1, time_outputs)
目前我有固定数量的时间步骤。但是我想把它改成只有一个时间步,但要记住批次之间的状态。因此,我需要为每个图层创建一个状态变量,并为每个图层指定最终状态。这样的事情。
for r in range(RNNS):
with tf.variable_scope('recurent_%d' % r) as scope:
saved_state = tf.get_variable('saved_state', ...)
rnn_outputs[r], state = rnn_func(rnn_outputs[r - 1], saved_state)
saved_state = tf.assign(saved_state, state)
然后,对于每个图层,我需要在sess.run函数中评估已保存的状态以及调用我的训练函数。我需要为每个rnn层执行此操作。这看起来很麻烦。我需要跟踪每个保存的状态并在运行中对其进行评估。此外,运行将需要将状态从我的GPU复制到主机内存,这将是低效且不必要的。有没有更好的方法呢?
答案 0 :(得分:22)
这是通过定义状态变量来更新LSTM的初始状态的代码,state_is_tuple=True
。它还支持多个层。
我们定义了两个函数 - 一个用于获取具有初始零状态的状态变量和一个用于返回操作的函数,我们可以将其传递给session.run
以使用LSTM更新状态变量'最后隐藏的状态。
def get_state_variables(batch_size, cell):
# For each layer, get the initial state and make a variable out of it
# to enable updating its value.
state_variables = []
for state_c, state_h in cell.zero_state(batch_size, tf.float32):
state_variables.append(tf.contrib.rnn.LSTMStateTuple(
tf.Variable(state_c, trainable=False),
tf.Variable(state_h, trainable=False)))
# Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
return tuple(state_variables)
def get_state_update_op(state_variables, new_states):
# Add an operation to update the train states with the last state tensors
update_ops = []
for state_variable, new_state in zip(state_variables, new_states):
# Assign the new state to the state variables on this layer
update_ops.extend([state_variable[0].assign(new_state[0]),
state_variable[1].assign(new_state[1])])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
return tf.tuple(update_ops)
我们可以使用它来更新每批后的LSTM状态。请注意,我使用tf.nn.dynamic_rnn
进行展开:
data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell_layer = tf.contrib.rnn.GRUCell(256)
cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers)
# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)
# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)
# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})
与this answer的主要区别在于state_is_tuple=True
使LSTM的状态成为包含两个变量(单元状态和隐藏状态)的LSTMStateTuple,而不仅仅是单个变量。使用多个层然后使LSTM状态成为LSTMStateTuples的元组 - 每层一个。
使用经过训练的模型进行预测/解码时,您可能希望将状态重置为零。然后,您可以使用此功能:
def get_state_reset_op(state_variables, cell, batch_size):
# Return an operation to set each variable in a list of LSTMStateTuples to zero
zero_states = cell.zero_state(batch_size, tf.float32)
return get_state_update_op(state_variables, zero_states)
例如上面的内容:
reset_state_op = get_state_reset_op(state, cell, max_batch_size)
# Reset the state to zero before feeding input
sess.run([reset_state_op])
sess.run([outputs, update_op], {data: ...})
答案 1 :(得分:2)
我现在使用tf.control_dependencies保存RNN状态。这是一个例子。
saved_states = [tf.get_variable('saved_state_%d' % i, shape = (BATCH_SIZE, sz), trainable = False, initializer = tf.constant_initializer()) for i, sz in enumerate(rnn.state_size)]
W = tf.get_variable('W', shape = (2 * RNN_SIZE, RNN_SIZE), initializer = tf.truncated_normal_initializer(0.0, 1 / np.sqrt(2 * RNN_SIZE)))
b = tf.get_variable('b', shape = (RNN_SIZE,), initializer = tf.constant_initializer())
rnn_output, states = rnn(last_output, saved_states)
with tf.control_dependencies([tf.assign(a, b) for a, b in zip(saved_states, states)]):
dense_input = tf.concat(1, (last_output, rnn_output))
dense_output = tf.tanh(tf.matmul(dense_input, W) + b)
last_output = dense_output + last_output
我只是确保我的图表的一部分依赖于保存状态。
答案 2 :(得分:2)