众所周知,TensorFlow 的模型文件可以分为两类:frozen graph(冻结图)和 SavedModel。SavedModel 目录的结构如下:
variables/
variables.data-?????-of-?????
variables.index
saved_model.pb|saved_model.pbtxt
冻结图(frozen graph)则是单个文件:
frozen_graph.pb
我想在 C++ 中运行预先定义好的图。c_api.cc 中包含函数
TF_LoadSessionFromSavedModel
// Loads the SavedModel found in `export_dir` for the given `tags`, imports its
// GraphDef into the (required-to-be-empty) `graph`, and returns a TF_Session
// wrapping the session produced by the loader.
//
// `run_options` may be null; when non-null it must hold a serialized
// RunOptions proto. `meta_graph_def` may be null; when non-null it receives
// the serialized MetaGraphDef of the loaded model.
//
// Returns nullptr and records the error in `status` on failure.
TF_Session* TF_LoadSessionFromSavedModel(
    const TF_SessionOptions* session_options, const TF_Buffer* run_options,
    const char* export_dir, const char* const* tags, int tags_len,
    TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that
// the tensorflow/cc/saved_model:loader build target is Android friendly.
#ifdef __ANDROID__
  status->status = tensorflow::errors::Unimplemented(
      "Loading a SavedModel is not supported in Android. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return nullptr;
#else
  mutex_lock graph_lock(graph->mu);

  // The model's nodes are imported wholesale, so the target graph must not
  // already contain any.
  if (!graph->name_map.empty()) {
    status->status = InvalidArgument("Graph is non-empty.");
    return nullptr;
  }

  // Decode the caller-supplied RunOptions, if any.
  RunOptions decoded_run_options;
  if (run_options != nullptr) {
    const bool parsed_ok = decoded_run_options.ParseFromArray(
        run_options->data, run_options->length);
    if (!parsed_ok) {
      status->status = InvalidArgument("Unparseable RunOptions proto");
      return nullptr;
    }
  }

  // Collect the C-string tags into the set type the loader expects.
  std::unordered_set<tensorflow::string> requested_tags;
  for (int idx = 0; idx < tags_len; ++idx) {
    requested_tags.insert(tensorflow::string(tags[idx]));
  }

  tensorflow::SavedModelBundle bundle;
  status->status = tensorflow::LoadSavedModel(
      session_options->options, decoded_run_options, export_dir,
      requested_tags, &bundle);
  if (!status->status.ok()) {
    return nullptr;
  }

  // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
  // extends using GraphDefs. The Graph instance is different, but equivalent
  // to the one used to create the session.
  //
  // TODO(jhseu): When Session is modified to take Graphs instead of
  // GraphDefs, return the Graph generated in LoadSavedModel().
  TF_ImportGraphDefOptions* import_options = TF_NewImportGraphDefOptions();
  GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
                            import_options, nullptr, 0, status);
  TF_DeleteImportGraphDefOptions(import_options);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }

  // Hand the serialized MetaGraphDef back to the caller when requested.
  if (meta_graph_def != nullptr) {
    status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
    if (!status->status.ok()) {
      return nullptr;
    }
  }

  // Ownership of the underlying session moves from the bundle to the new
  // TF_Session.
  TF_Session* result = new TF_Session(bundle.session.release(), graph);
  graph->num_sessions += 1;
  result->last_num_graph_nodes = graph->graph.num_node_ids();
  return result;
#endif  // __ANDROID__
}
但不包含
TF_LoadSessionFromFrozenGraph
所以我把这个函数添加到了 c_api.cc 中:
// Loads a frozen graph (a binary-serialized GraphDef) from `frozenPbFile`,
// creates a session running it, and imports the same GraphDef into `graph`.
//
// `graph` must be empty. On success returns a new TF_Session bound to
// `graph`. On failure returns nullptr with the error recorded in `status`;
// any session created along the way is destroyed rather than leaked.
TF_Session* TF_LoadSessionFromFrozenGraph(
    const TF_SessionOptions* session_options, const char* frozenPbFile,
    TF_Graph* graph, TF_Status* status) {
  // A null path would otherwise be converted to a string, which is undefined.
  if (frozenPbFile == nullptr) {
    status->status = InvalidArgument("Frozen graph file path is null.");
    return nullptr;
  }

  mutex_lock l(graph->mu);
  if (!graph->name_map.empty()) {
    status->status = InvalidArgument("Graph is non-empty.");
    return nullptr;
  }

  // Read the GraphDef from the frozen .pb file.
  tensorflow::GraphDef graph_def;
  status->status =
      ReadBinaryProto(session_options->options.env, frozenPbFile, &graph_def);
  if (!status->status.ok()) return nullptr;

  // Create the session. The Status-returning overload reports why creation
  // failed instead of just handing back a null pointer, and the unique_ptr
  // guarantees the session is deleted on every early return below.
  tensorflow::Session* raw_session = nullptr;
  status->status =
      tensorflow::NewSession(session_options->options, &raw_session);
  std::unique_ptr<tensorflow::Session> session(raw_session);
  if (!status->status.ok()) return nullptr;

  status->status = session->Create(graph_def);
  if (!status->status.ok()) return nullptr;

  // Mirror the GraphDef into the TF_Graph so that operations can be looked up
  // through the C API.
  TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
  GraphImportGraphDefLocked(graph, graph_def, import_opts, nullptr, 0, status);
  TF_DeleteImportGraphDefOptions(import_opts);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  // Ownership of the session moves to the TF_Session wrapper.
  TF_Session* tf_session = new TF_Session(session.release(), graph);
  graph->num_sessions += 1;
  tf_session->last_num_graph_nodes = graph->graph.num_node_ids();
  return tf_session;
}
现在我要测试函数 TF_LoadSessionFromFrozenGraph。我发现 TF_LoadSessionFromSavedModel 的测试代码位于 c_api_test.cc 中:
// Loads the half_plus_two SavedModel through the C API, feeds it the inputs
// {0, 1, 2, 3} encoded as tensorflow::Example protos via the
// "regress_x_to_y" signature, and checks the outputs equal (input / 2) + 2.
TEST(CAPI, SavedModel) {
  // Load the saved model.
  const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
  const string saved_model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Buffer* run_options = TF_NewBufferFromString("", 0);
  TF_Buffer* metagraph = TF_NewBuffer();
  TF_Status* s = TF_NewStatus();
  const char* tags[] = {tensorflow::kSavedModelTagServe};
  TF_Graph* graph = TF_NewGraph();
  TF_Session* session = TF_LoadSessionFromSavedModel(
      opt, run_options, saved_model_dir.c_str(), tags, 1, graph, metagraph, s);
  TF_DeleteBuffer(run_options);
  TF_DeleteSessionOptions(opt);

  // Everything below uses the session and the returned MetaGraphDef, so a
  // failed load must stop the test here instead of continuing with a null
  // session.
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(session != nullptr);

  tensorflow::MetaGraphDef metagraph_def;
  const bool metagraph_parsed =
      metagraph_def.ParseFromArray(metagraph->data, metagraph->length);
  TF_DeleteBuffer(metagraph);
  ASSERT_TRUE(metagraph_parsed);

  CSession csession(session);

  // Retrieve the regression signature from meta graph def. References avoid
  // copying the signature map and the SignatureDef.
  const auto& signature_def_map = metagraph_def.signature_def();
  const auto& signature_def = signature_def_map.at("regress_x_to_y");
  const string input_name =
      signature_def.inputs().at(tensorflow::kRegressInputs).name();
  const string output_name =
      signature_def.outputs().at(tensorflow::kRegressOutputs).name();

  // Write {0, 1, 2, 3} as tensorflow::Example inputs.
  Tensor input(tensorflow::DT_STRING, TensorShape({4}));
  for (tensorflow::int64 i = 0; i < input.NumElements(); ++i) {
    tensorflow::Example example;
    auto* feature_map = example.mutable_features()->mutable_feature();
    (*feature_map)["x"].mutable_float_list()->add_value(i);
    input.flat<string>()(i) = example.SerializeAsString();
  }

  // Signature entries are tensor names; strip the output index to get the
  // operation name to look up in the graph.
  const tensorflow::string input_op_name =
      tensorflow::ParseTensorName(input_name).first.ToString();
  TF_Operation* input_op =
      TF_GraphOperationByName(graph, input_op_name.c_str());
  ASSERT_TRUE(input_op != nullptr);
  csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  const tensorflow::string output_op_name =
      tensorflow::ParseTensorName(output_name).first.ToString();
  TF_Operation* output_op =
      TF_GraphOperationByName(graph, output_op_name.c_str());
  ASSERT_TRUE(output_op != nullptr);
  csession.SetOutputs({output_op});
  csession.Run(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // The result is expected to be a 4x1 float tensor.
  TF_Tensor* out = csession.output_tensor(0);
  ASSERT_TRUE(out != nullptr);
  EXPECT_EQ(TF_FLOAT, TF_TensorType(out));
  EXPECT_EQ(2, TF_NumDims(out));
  EXPECT_EQ(4, TF_Dim(out, 0));
  EXPECT_EQ(1, TF_Dim(out, 1));
  float* values = static_cast<float*>(TF_TensorData(out));
  // These values are defined to be (input / 2) + 2.
  EXPECT_EQ(2, values[0]);
  EXPECT_EQ(2.5, values[1]);
  EXPECT_EQ(3, values[2]);
  EXPECT_EQ(3.5, values[3]);

  csession.CloseAndDelete(s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
// Verifies that TF_LoadSessionFromSavedModel accepts a null `run_options`
// and a null `meta_graph_def`, and that the resulting session can be closed
// and deleted cleanly.
TEST(CAPI, SavedModelNullArgsAreValid) {
  const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
  const string saved_model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Status* s = TF_NewStatus();
  const char* tags[] = {tensorflow::kSavedModelTagServe};
  TF_Graph* graph = TF_NewGraph();
  // NULL run_options and meta_graph_def should work.
  TF_Session* session = TF_LoadSessionFromSavedModel(
      opt, nullptr, saved_model_dir.c_str(), tags, 1, graph, nullptr, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_TRUE(session != nullptr);
  TF_DeleteSessionOptions(opt);
  // Only tear the session down if loading produced one; a failed load is
  // already reported above and must not pass a null session on.
  if (session != nullptr) {
    TF_CloseSession(session, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_DeleteSession(session, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  }
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
如何模仿 SavedModel 测试的代码来编写冻结图(frozen graph)的测试代码?
我是 C++ 新手,如果我的问题很愚蠢,请见谅。