如何更有效地在Tensorflow中运行图表?

时间:2018-05-17 12:53:58

标签: tensorflow

如何在不重新启动操作的情况下更有效地运行已保存的Tensorflow模型? 每当我调用预测方法时,获得分数结果需要时间。

def predict(self, msg):
tensor = self._convector.sentence_to_id_vector(msg)
if tensor is 0:
    return 0
else:
    graph = tf.Graph()
    with graph.as_default():
        sess = tf.Session()
        saver = tf.train.import_meta_graph("{}/model.ckpt.meta".format(self._model_dir))
        saver.restore(sess, ("{}/model.ckpt".format(self._model_dir)))
        input = graph.get_operation_by_name('input').outputs[0]
        seq_len = graph.get_operation_by_name('lengths').outputs[0]
        dropout_keep_prob = graph.get_operation_by_name('dropout_keep_prob').outputs[0]
        prediction = graph.get_operation_by_name('final_layer/softmax/predictions').outputs[0]
        model_seq_len = self._convector.sequence_len
        sample_len = self._convector.sample_len

        score = sess.run(prediction, feed_dict={input: tensor.reshape(1, model_seq_len),
                                                seq_len: [sample_len], dropout_keep_prob: 1.0})
        print('score result: [{0:.2f}, {1:.2f}]'.format(score[0, 0], score[0, 1]))
        return score

1 个答案:

答案 0 :(得分:0)

每次运行预测时都要导入图形和会话。

将图形和会话保存在内存中(使它们在对象或全局变量中保持可用)。运行新预测时,您只需执行sess.run行。

您目前一次运行1个预测。如果您同时运行一堆请求(基本上为预测进行批处理),则可以更有效地执行模型内部的矩阵乘法。这并不总是可行,但只要您能够将执行放在内存中,它确实提供了良好的性能。

您从检查点导入模型。这可能包括有利于培训的操作安置,但可能不适合预测。使用saved_model导出功能,它将清除它,允许客户端选择最佳位置。这应该允许您正确使用GPU。作为副作用,它会为您存储输入和输出张量的名称,这样您就不需要对它们进行硬编码。

最后,请确保您拥有正确版本的tensorflow。基本的没有GPU支持甚至AVX2支持(现代CPU上的特殊指令)。根据模型和工作量的不同,它们可能会有所不同,也可能会有所不同。