如何在 TensorFlow C++ 中测试加载冻结图(frozen graph)模型文件的函数(类似于加载 SavedModel 的函数)?

时间:2018-03-01 07:58:47

标签: c++ tensorflow

众所周知,TensorFlow 的模型文件可以分为两类:frozen graph 和 SavedModel。SavedModel 目录的结构如下:

 variables/
        variables.data-?????-of-?????
        variables.index  
 saved_model.pb|saved_model.pbtxt

冻结图(frozen graph)则是单个文件:

frozen_graph.pb

我想在 C++ 中运行预定义的图。c_api.cc 中包含函数 TF_LoadSessionFromSavedModel:

// Loads the SavedModel stored in `export_dir`, selecting the MetaGraphDef that
// carries `tags`, and returns a TF_Session that runs it.
//
// Parameters:
//   session_options: forwarded to tensorflow::LoadSavedModel().
//   run_options:     optional serialized RunOptions proto; nullptr means
//                    default-constructed RunOptions.
//   export_dir:      SavedModel directory to load.
//   tags, tags_len:  array of `tags_len` tag strings used to select the
//                    MetaGraphDef.
//   graph:           must be empty; receives an import of the loaded GraphDef.
//   meta_graph_def:  optional; when non-null it receives the serialized
//                    MetaGraphDef.
//   status:          receives the outcome of the call.
// Returns the new session, or nullptr with `status` set on failure. Android
// builds always fail with Unimplemented.
TF_Session* TF_LoadSessionFromSavedModel(
    const TF_SessionOptions* session_options, const TF_Buffer* run_options,
    const char* export_dir, const char* const* tags, int tags_len,
    TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that
// the tensorflow/cc/saved_model:loader build target is Android friendly.
#ifdef __ANDROID__
  status->status = tensorflow::errors::Unimplemented(
      "Loading a SavedModel is not supported in Android. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return nullptr;
#else
  // Held for the whole call: the graph is checked for emptiness and then
  // populated, and no other writer may interleave.
  mutex_lock graph_lock(graph->mu);
  if (!graph->name_map.empty()) {
    status->status = InvalidArgument("Graph is non-empty.");
    return nullptr;
  }

  // A missing run_options buffer leaves the proto default-constructed.
  RunOptions parsed_run_options;
  if (run_options != nullptr) {
    const bool parsed_ok = parsed_run_options.ParseFromArray(
        run_options->data, run_options->length);
    if (!parsed_ok) {
      status->status = InvalidArgument("Unparseable RunOptions proto");
      return nullptr;
    }
  }

  std::unordered_set<tensorflow::string> requested_tags;
  for (int idx = 0; idx < tags_len; ++idx) {
    requested_tags.emplace(tags[idx]);
  }

  tensorflow::SavedModelBundle loaded;
  status->status =
      tensorflow::LoadSavedModel(session_options->options, parsed_run_options,
                                 export_dir, requested_tags, &loaded);
  if (!status->status.ok()) {
    return nullptr;
  }

  // Mirror the loaded GraphDef into the caller's TF_Graph. This is safe as long
  // as Session extends using GraphDefs: the Graph instance differs from the
  // session's, but is equivalent to it.
  //
  // TODO(jhseu): When Session is modified to take Graphs instead of
  // GraphDefs, return the Graph generated in LoadSavedModel().
  TF_ImportGraphDefOptions* import_options = TF_NewImportGraphDefOptions();
  GraphImportGraphDefLocked(graph, loaded.meta_graph_def.graph_def(),
                            import_options, nullptr, 0, status);
  TF_DeleteImportGraphDefOptions(import_options);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }

  if (meta_graph_def != nullptr) {
    status->status = MessageToBuffer(loaded.meta_graph_def, meta_graph_def);
    if (!status->status.ok()) {
      return nullptr;
    }
  }

  // Ownership of the tensorflow::Session moves from the bundle into TF_Session.
  TF_Session* tf_session = new TF_Session(loaded.session.release(), graph);
  ++graph->num_sessions;
  tf_session->last_num_graph_nodes = graph->graph.num_node_ids();
  return tf_session;
#endif  // __ANDROID__
}

但不包含 TF_LoadSessionFromFrozenGraph。

所以我在 c_api.cc 中自行添加了该函数:

// Loads a frozen graph (a single binary GraphDef file) and returns a TF_Session
// created from it.
//
// `graph` must be empty. On success it receives an import of the same GraphDef
// the session was created from, so feed/fetch operations can be looked up with
// TF_GraphOperationByName(). On failure `status` is set, nullptr is returned,
// and the intermediate tensorflow::Session is freed (no leak).
//
// Parameters:
//   session_options: options used to create the session; its `env` is also
//                    used to read `frozenPbFile`.
//   frozenPbFile:    path of the binary GraphDef (.pb) file; must not be null.
//   graph:           empty graph that receives the imported nodes.
//   status:          receives the outcome of the call.
TF_Session* TF_LoadSessionFromFrozenGraph(
    const TF_SessionOptions* session_options, const char* frozenPbFile,
    TF_Graph* graph, TF_Status* status)
{
    if (frozenPbFile == nullptr) {
        status->status = InvalidArgument("frozenPbFile must not be null.");
        return nullptr;
    }

    mutex_lock l(graph->mu);
    if (!graph->name_map.empty()) {
        status->status = InvalidArgument("Graph is non-empty.");
        return nullptr;
    }

    // Read the frozen GraphDef from disk.
    tensorflow::GraphDef graph_def;
    status->status = tensorflow::ReadBinaryProto(session_options->options.env,
                                                 frozenPbFile, &graph_def);
    if (!status->status.ok()) return nullptr;

    // Create the session. The Status-returning overload reports why creation
    // failed. The single-argument overload may return nullptr, which would
    // then be dereferenced.
    tensorflow::Session* session = nullptr;
    status->status = tensorflow::NewSession(session_options->options, &session);
    if (!status->status.ok()) return nullptr;

    status->status = session->Create(graph_def);
    if (!status->status.ok()) {
        delete session;  // Not yet owned by a TF_Session; free it here.
        return nullptr;
    }

    // Mirror the GraphDef into the caller's TF_Graph so operations can be
    // looked up by name for feeds and fetches.
    TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
    GraphImportGraphDefLocked(graph, graph_def, import_opts, nullptr, 0, status);
    TF_DeleteImportGraphDefOptions(import_opts);
    if (TF_GetCode(status) != TF_OK) {
        delete session;  // Not yet owned by a TF_Session; free it here.
        return nullptr;
    }

    // TF_Session takes ownership of `session` from here on.
    TF_Session* tf_session = new TF_Session(session, graph);
    graph->num_sessions += 1;
    tf_session->last_num_graph_nodes = graph->graph.num_node_ids();
    return tf_session;
}

现在我想测试函数 TF_LoadSessionFromFrozenGraph。我发现 TF_LoadSessionFromSavedModel 的测试代码位于 c_api_test.cc 中:

    TEST(CAPI, SavedModel) {
  // Load the saved model.
  const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
  const string saved_model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Buffer* run_options = TF_NewBufferFromString("", 0);
  TF_Buffer* metagraph = TF_NewBuffer();
  TF_Status* s = TF_NewStatus();
  const char* tags[] = {tensorflow::kSavedModelTagServe};
  TF_Graph* graph = TF_NewGraph();
  TF_Session* session = TF_LoadSessionFromSavedModel(
      opt, run_options, saved_model_dir.c_str(), tags, 1, graph, metagraph, s);
  TF_DeleteBuffer(run_options);
  TF_DeleteSessionOptions(opt);
  tensorflow::MetaGraphDef metagraph_def;
  const bool metagraph_parsed =
      metagraph_def.ParseFromArray(metagraph->data, metagraph->length);
  TF_DeleteBuffer(metagraph);

  // Fatal on failure: on a failed load `session` is null, and continuing would
  // wrap and dereference it in CSession.
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(metagraph_parsed);
  CSession csession(session);

  // Retrieve the regression signature from the meta graph def. Bind by
  // reference: the protobuf map and SignatureDef need not be copied.
  const auto& signature_def_map = metagraph_def.signature_def();
  ASSERT_EQ(1, signature_def_map.count("regress_x_to_y"));
  const auto& signature_def = signature_def_map.at("regress_x_to_y");

  const string input_name =
      signature_def.inputs().at(tensorflow::kRegressInputs).name();
  const string output_name =
      signature_def.outputs().at(tensorflow::kRegressOutputs).name();

  // Write {0, 1, 2, 3} as tensorflow::Example inputs.
  Tensor input(tensorflow::DT_STRING, TensorShape({4}));
  for (tensorflow::int64 i = 0; i < input.NumElements(); ++i) {
    tensorflow::Example example;
    auto* feature_map = example.mutable_features()->mutable_feature();
    (*feature_map)["x"].mutable_float_list()->add_value(i);
    input.flat<string>()(i) = example.SerializeAsString();
  }

  const tensorflow::string input_op_name =
      tensorflow::ParseTensorName(input_name).first.ToString();
  TF_Operation* input_op =
      TF_GraphOperationByName(graph, input_op_name.c_str());
  ASSERT_TRUE(input_op != nullptr);
  csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  const tensorflow::string output_op_name =
      tensorflow::ParseTensorName(output_name).first.ToString();
  TF_Operation* output_op =
      TF_GraphOperationByName(graph, output_op_name.c_str());
  ASSERT_TRUE(output_op != nullptr);
  csession.SetOutputs({output_op});
  csession.Run(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Tensor* out = csession.output_tensor(0);
  ASSERT_TRUE(out != nullptr);
  EXPECT_EQ(TF_FLOAT, TF_TensorType(out));
  EXPECT_EQ(2, TF_NumDims(out));
  EXPECT_EQ(4, TF_Dim(out, 0));
  EXPECT_EQ(1, TF_Dim(out, 1));
  float* values = static_cast<float*>(TF_TensorData(out));
  // These values are defined to be (input / 2) + 2.
  EXPECT_EQ(2, values[0]);
  EXPECT_EQ(2.5, values[1]);
  EXPECT_EQ(3, values[2]);
  EXPECT_EQ(3.5, values[3]);

  csession.CloseAndDelete(s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}

TEST(CAPI, SavedModelNullArgsAreValid) {
  const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
  const string saved_model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Status* s = TF_NewStatus();
  const char* tags[] = {tensorflow::kSavedModelTagServe};
  TF_Graph* graph = TF_NewGraph();
  // NULL run_options and meta_graph_def should work.
  TF_Session* session = TF_LoadSessionFromSavedModel(
      opt, nullptr, saved_model_dir.c_str(), tags, 1, graph, nullptr, s);
  // The options are no longer needed; free them before any fatal assertion can
  // return early.
  TF_DeleteSessionOptions(opt);
  // Fatal on failure: TF_CloseSession/TF_DeleteSession below would dereference
  // the null session returned by a failed load.
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(session != nullptr);
  TF_CloseSession(session, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteSession(session, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}

如何仿照 SavedModel 的测试代码,为冻结图编写测试代码?

我是 C++ 新手,如果问题比较初级,还请见谅。

0 个答案:

没有答案