使用已保存的Tensorflow Estimator和C ++ API

时间:2016-12-24 16:59:33

标签: python c++ tensorflow

我已按照https://www.tensorflow.org/versions/r0.11/tutorials/estimators/中的描述在Python中编写了鲍鱼估算器。我希望保存估算器的状态,然后用C ++加载它并用它来进行预测。

要从Python保存,我使用model_dir构造函数中的tf.contrib.learn.Estimator参数,该参数创建(文本)protobuf文件和多个检查点文件。然后我使用freeze_graph.py工具(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)将检查点和protobuf文件合并到一个独立的GraphDef文件中。

我使用C ++ API加载此文件,将一些输入值加载到Tensor中,然后运行会话。 protobuf文件中的输入节点称为“输入”,输出节点称为“输出”,两者都是占位符节点。

// ...
std::vector<std::pair<string, tensorflow::Tensor>> inputs = 
{
    {"input", inputTensor}
};

std::vector<tensorflow::Tensor> outputs;

status = pSession->Run(inputs, {"output"}, {}, &outputs);

但是,由于输出节点是占位符节点,因此需要输入值,因此会失败。但是你不能同时提供和获取节点值,因此我无法访问估算器的输出。为什么输出节点是占位符节点?

从Python中保存经过训练的估算器并加载它以便在C ++中进行预测的最佳方法是什么?

0 个答案:

没有答案