我如何实际执行保存的TensorFlow模型?

时间:2016-04-08 20:34:39

标签: neural-network tensorflow deep-learning lstm recurrent-neural-network

Tensorflow新手在这里。我正在尝试建立一个RNN。我的输入数据是一组大小为instance_size的矢量实例,表示每个时间步的一组粒子的(x,y)位置。 (由于实例已经具有语义内容,因此它们不需要嵌入。)目标是学习在下一步预测粒子的位置。

RNN tutorial之后并略微调整包含的RNN代码,我创建了一个或多或少像这样的模型(省略一些细节):

inputs, self._input_data = tf.placeholder(tf.float32, [batch_size, num_steps, instance_size])
self._targets = tf.placeholder(tf.float32, [batch_size, num_steps, instance_size])

with tf.variable_scope("lstm_cell", reuse=True):
  lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size, forget_bias=0.0)
  if is_training and config.keep_prob < 1:
    lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
        lstm_cell, output_keep_prob=config.keep_prob)
  cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers)

self._initial_state = cell.zero_state(batch_size, tf.float32)

from tensorflow.models.rnn import rnn
inputs = [tf.squeeze(input_, [1])
          for input_ in tf.split(1, num_steps, inputs)]
outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state)

output = tf.reshape(tf.concat(1, outputs), [-1, hidden_size])
softmax_w = tf.get_variable("softmax_w", [hidden_size, instance_size])
softmax_b = tf.get_variable("softmax_b", [instance_size])
logits = tf.matmul(output, softmax_w) + softmax_b
loss = position_squared_error_loss(
    tf.reshape(logits, [-1]),
    tf.reshape(self._targets, [-1]),
)
self._cost = cost = tf.reduce_sum(loss) / batch_size
self._final_state = state

然后我创建一个saver = tf.train.Saver(),迭代数据以使用给定的run_epoch()方法训练它,并用saver.save()写出参数。到目前为止,非常好。

但是我如何使用训练有素的模型呢?教程此时停止。从the docs on tf.train.Saver.restore()开始,为了回读变量,我需要设置与我保存变量时运行的完全相同的图,或者选择性地恢复特定变量。无论哪种方式,这意味着我的新模型将需要大小为batch_size x num_steps x instance_size的输入。但是,我现在想要的是在大小为num_steps x instance_size的输入上对模型进行单个前向传递,并读出单个instance_size大小的结果(下一个时间步的预测);换句话说,我想创建一个接受不同尺寸张量的模型,而不是我训练过的张量。我可以通过将现有模型传递给我的预期数据batch_size次来对其进行处理,但这似乎不是最佳做法。最好的方法是什么?

1 个答案:

答案 0 :(得分:2)

您必须使用batch_size = 1创建一个结构相同的新图表,并使用tf.train.Saver.restore()导入已保存的变量。您可以在ptb_word_lm.py中查看他们如何定义具有可变批量大小的多个模型:https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/rnn/ptb/ptb_word_lm.py

因此,您可以拥有一个单独的文件,例如,使用所需的batch_size实例化图形,然后恢复已保存的变量。然后你可以执行你的图表。