使用tensorflow c ++ api运行会话比使用python

时间:2017-05-10 12:37:35

标签: tensorflow

我正在尝试使用tensorflow c ++ api(仅限CPU)运行SqueezeDet。我已经冻结了张量流图并从C ++加载它。虽然在检测质量方面一切都很好,但性能比python慢​​得多。可能是什么原因?

简化,我的代码如下:

  int main (int argc, const char * argv[])
  {
    // Initializing graph 
    tensorflow::GraphDef graph_def;
    // Folder in which graph data is located
    string graph_file_name = "Model/graph.pb";
    // Loading graph 
    tensorflow::Status graph_loaded_status =  ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
    if (!graph_loaded_status.ok())
    {
      cout << graph_loaded_status.ToString() << endl;
      return 1;
    }
    unique_ptr<tensorflow::Session> session_sqdet(tensorflow::NewSession(tensorflow::SessionOptions()));
    tensorflow::Status session_create_status = session_sqdet->Create(graph_def);
    if (!session_create_status.ok())
    {
      cout << "Session create status: fail." << endl;
      return 1;
    }
    while ()
    {
      /* create & preprocess batch */

      session.Run({{ "image_input", input_tensor}, {"keep_prob", prob_tensor}}, {"probability/score", "bbox/trimming/bbox"}, {}, &final_output);

      /* do some postprocessing */
    }
  }

我尝试过:

1)使用优化标志 - 全部开启,没有警告。

2)使用批处理:性能提高,但python和C ++之间的差距仍然很大(运行会话需要1s vs 2.4s,batch_size = 20)。

任何帮助都将受到高度赞赏。

1 个答案:

答案 0 :(得分:1)

我花了很多时间在这个问题上(大多数是因为我犯了愚蠢的错误),但我终于解决了它。现在我想在这里发布我的经验,因为它可能有用。

所以这些步骤我建议跟随面对同一问题的人(其中一些非常明显):

0)正确进行分析!确保您在多核/ GPU /您拥有的任何设置中使用可靠的工具。

1)检查tensorflow和所有相关的包是否都是在所有优化的基础上构建的。

2)冻结后优化图形。

3)如果您在训练和推理期间使用不同的批量大小,请确保已删除模型中的所有依赖项!请注意,否则您不会在结果质量方面出现错误消息或表现更差,您只会出现神秘的减速!