如何使用c ++在tensorflow中保存模型

时间:2017-11-19 09:55:46

标签: c++ tensorflow model save

如何使用c ++在Tensorflow中保存模型?我搜索了谷歌和百度,但没有找到任何解决方案。然后我阅读了tensorflow的api文档,并且介绍了关于C ++的更少介绍

2 个答案:

答案 0 :(得分:1)

模型保存仅在Python中实现。目前无法使用C ++ API保存模型。 C ++ API允许您加载和使用模型,而不是训练或保存它们。

答案 1 :(得分:0)

假设您对tensorflow C ++ API有基本的了解,并且知道如何使用C ++ API构造图形。您可以使用以下两个功能:

  1. tensorflow::WriteTextProto():您可以从tensorflow::GraphDef获取tensorflow::Scope::ToGraphDef()(代表您定义的所有运算,例如加,乘,均值....等),保存{ {1}}发送文本protobuf文件

  2. tensorflow::GraphDef将参数矩阵的当前状态保存到外部文件(检查点),虽然有点复杂,但对我来说效果很好

首先,您必须通过调用tensorflow::checkpoint::TensorSliceWriter来获得训练有素的参数,这将向tensorflow::Session::Run返回参数矩阵的列表(请参见下面的示例):

output_tensor

其中上面的std::vector<tensorflow::Tensor> output_tensor; tensorflow::Session::Run({}, {"name_of_param_mtx_1", "name_of_param_mtx_2",}, {}, &output_tensor); name_of_param_mtx_1应该是name_of_param_mtx_2中参数矩阵的名称,例如

tensorflow::Variable

然后您需要为auto name_of_param_mtx_1 = tensorflow::ops::Variable (root.WithOpName("name_of_param_mtx_1"), {7, 17}, tensorflow::DT_FLOAT); 准备以下内容:

    调用tensorflow::checkpoint::TensorSliceWriter
  • 参数原始数据的基地址
  • 每个tensorflow::Tensor.tensor_data().data()
  • 形状,通过调用tensorflow::Tensor。例如7x17 2D参数矩阵,NUM_DIMENSION可以是0和1,其中tensorflow :: Tensor :: dim_size(0)是7而tensorflow :: Tensor :: dim_size(1)是17。
  • 此检查点的名称,该名称必须与一个文件中其他检查点的名称唯一
  • 通过调用tensorflow::Tensor::dim_size(NUM_DIMENSION)创建tensorflow::TensorSlice,看来tensorflow::TensorSlice::ParseOrDie("-:-")的唯一参数将在内部进行分析,例如tensorflow::TensorSlice::ParseOrDie表示取矩阵的所有项。如果用户只想要一部分训练有素的参数矩阵,例如只占所有行的第二列,则字符串参数可能是-:-,我还没有弄清楚-:2的这种高级用法。

希望有帮助。