Saving LSTM RNN state between runs in TensorFlow

Date: 2017-01-28 21:33:31

Tags: python tensorflow

What is the best way to save LSTM state between runs in TensorFlow? For the prediction phase I need to pass data in one time step at a time, because the input at the next time step depends on the output of the previous one.

I used the suggestion from this post: Tensorflow, best way to save state in RNNs? and tested it by repeatedly passing in the same input without running the optimizer. If I understand correctly, the state is being saved if the output changes on every run, and it is not if the output stays the same. The result was that it saved the state the first time, but after that the output stayed constant.
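
For context, a common alternative is to keep the state outside the graph entirely and feed it back in through placeholders on every sess.run call. A rough sketch of that idea (made-up sizes, a single layer, and my_steps as a stand-in for the per-step inputs; this is not the code I actually tested, which follows below):

 import numpy as np
 import tensorflow as tf

 num_units, input_size = 122, 126                            # made-up sizes for this sketch
 x_step = tf.placeholder(tf.float32, [1, 1, input_size])     # one time step per run
 c_in = tf.placeholder(tf.float32, [1, num_units])           # previous cell state
 h_in = tf.placeholder(tf.float32, [1, num_units])           # previous hidden state
 cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
 init = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
 step_out, state = tf.nn.dynamic_rnn(cell, x_step, initial_state=init)

 with tf.Session() as sess:
     sess.run(tf.global_variables_initializer())
     c = np.zeros([1, num_units], dtype=np.float32)
     h = np.zeros([1, num_units], dtype=np.float32)
     for step_input in my_steps:
         out, (c, h) = sess.run([step_out, state],
                                feed_dict={x_step: step_input, c_in: c, h_in: h})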

Here is my code:

 import tensorflow as tf

 import data_generator  # my own module that loads the training data

 pieces = data_generator.load_pieces(5)

 batches = 100
 sizes = [126, 122]
 steps = 128
 layers = 2

 x = tf.placeholder(tf.float32, shape=[batches, steps, sizes[0]])
 y_ = tf.placeholder(tf.float32, shape=[batches, steps, sizes[1]])

 W = tf.Variable(tf.random_normal([sizes[0], sizes[1]]))
 b = tf.Variable(tf.random_normal([sizes[1]]))

 layer = tf.nn.rnn_cell.BasicLSTMCell(sizes[0], forget_bias=0.0)
 lstm = tf.nn.rnn_cell.MultiRNNCell([layer] * layers)

 # ~~~~~ code from linked post ~~~~~
 def get_state_variables(batch_size, cell):
     # For each layer, get the initial state and make a variable out of it
     # to enable updating its value.
     state_variables = []
     for state_c, state_h in cell.zero_state(batch_size, tf.float32):
         state_variables.append(tf.nn.rnn_cell.LSTMStateTuple(
             tf.Variable(state_c, trainable=False),
             tf.Variable(state_h, trainable=False)))
     # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
     return tuple(state_variables)

 states = get_state_variables(batches, lstm)

 outputs, new_states = tf.nn.dynamic_rnn(lstm, x, initial_state=states, dtype=tf.float32)

 def get_state_update_op(state_variables, new_states):
     # Add an operation to update the train states with the last state tensors
     update_ops = []
     for state_variable, new_state in zip(state_variables, new_states):
         # Assign the new state to the state variables on this layer
         update_ops.extend([state_variable[0].assign(new_state[0]),
                            state_variable[1].assign(new_state[1])])
     # Return a tuple in order to combine all update_ops into a single operation.
     # The tuple's actual value should not be used.
     return tf.tuple(update_ops)

 update_op = get_state_update_op(states, new_states)
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 output = tf.reshape(outputs, [-1, sizes[0]])
 y = tf.nn.sigmoid(tf.matmul(output, W) + b)
 y = tf.reshape(y, [-1, steps, sizes[1]])

 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), [1, 2]))
 # train_step = tf.train.AdadeltaOptimizer().minimize(cross_entropy)

 sess = tf.InteractiveSession()
 sess.run(tf.global_variables_initializer())
 batch_x, batch_y = data_generator.get_batch(pieces)
 for i in range(500):
     error, _ = sess.run([cross_entropy, update_op], feed_dict={x: batch_x, y_: batch_y})
     print(str(i) + ': ' + str(error))

Here are the errors it printed:

  • 0: 419.861
  • 1: 419.756
  • 2: 419.756
  • 3: 419.756 ...

1 Answer:

Answer 0 (score: 0)

I recommend this answer, which I tried a few days ago. It works well.

By the way, there is a way to avoid having to set state_is_tuple to False:

import numpy as np
import tensorflow as tf


class CustomLSTMCell(tf.contrib.rnn.LSTMCell):
    def __init__(self, *args, **kwargs):
        # kwargs['state_is_tuple'] = False # force the use of a concatenated state.
        returns = super(CustomLSTMCell, self).__init__(
            *args, **kwargs)  # create an lstm cell
        # change the output size to the state size
        self._output_size = np.sum(self._state_size)
        return returns

    def __call__(self, inputs, state):
        output, next_state = super(
            CustomLSTMCell, self).__call__(inputs, state)
        # return two copies of the state, instead of the output and the state;
        # the reshape flattens the (c, h) pair into one row, which assumes batch size 1
        return tf.reshape(next_state, shape=[1, -1]), next_state
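
With this cell, the per-step output of tf.nn.dynamic_rnn is the concatenated (c, h) state itself, so you can read the final state out of the ordinary outputs tensor and feed it back in on the next run. A rough sketch of the wiring (hypothetical num_units / input_size, and batch size 1, which the reshape above assumes):

import tensorflow as tf

num_units, input_size = 126, 126                 # hypothetical sizes for this sketch
cell = CustomLSTMCell(num_units)

x = tf.placeholder(tf.float32, [1, None, input_size])
c_in = tf.placeholder(tf.float32, [1, num_units])
h_in = tf.placeholder(tf.float32, [1, num_units])
init_state = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)

outputs, _ = tf.nn.dynamic_rnn(cell, x, initial_state=init_state)
last = outputs[:, -1, :]                         # concatenated (c, h) after the last step
next_c, next_h = tf.split(last, 2, axis=1)       # split back into cell and hidden state

# At run time, fetch next_c / next_h together with your predictions and feed them
# back into c_in / h_in on the next sess.run call.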