以下代码来自tensorflow服务API:
// Implementation of Predict using the SavedModel SignatureDef format.
Status SavedModelPredict(const RunOptions& run_options, ServerCore* core,
const PredictRequest& request,
PredictResponse* response) {
// Validate signatures.
ServableHandle<SavedModelBundle> bundle;
TF_RETURN_IF_ERROR(core->GetServableHandle(request.model_spec(), &bundle));
const string signature_name = request.model_spec().signature_name().empty()
? kDefaultServingSignatureDefKey
: request.model_spec().signature_name();
auto iter = bundle->meta_graph_def.signature_def().find(signature_name);
if (iter == bundle->meta_graph_def.signature_def().end()) {
return errors::FailedPrecondition(strings::StrCat(
"Serving signature key \"", signature_name, "\" not found."));
}
SignatureDef signature = iter->second;
MakeModelSpec(request.model_spec().name(), signature_name,
bundle.id().version, response->mutable_model_spec());
std::vector<std::pair<string, Tensor>> input_tensors;
std::vector<string> output_tensor_names;
std::vector<string> output_tensor_aliases;
TF_RETURN_IF_ERROR(PreProcessPrediction(signature, request, &input_tensors,
&output_tensor_names,
&output_tensor_aliases));
std::vector<Tensor> outputs;
RunMetadata run_metadata;
TF_RETURN_IF_ERROR(bundle->session->Run(run_options, input_tensors,
output_tensor_names, {}, &outputs,
&run_metadata));
return PostProcessPredictionResult(signature, output_tensor_aliases, outputs,
response);
}
此代码使用存储的模型运行预测。在我的例子中,这个存储的模型是RNN。
在实际执行预测的以下行之后:
TF_RETURN_IF_ERROR(bundle->session->Run(run_options, input_tensors,
output_tensor_names, {}, &outputs,
&run_metadata));
我想将RNN的状态保存到文件/内存中,以便我可以在每次预测后的日期访问它们。我假设可以通过变量访问这些状态:
bundle->meta_graph_def
但目前尚不清楚如何具体访问RNN的状态值,然后将其保存到文件中。
答案 0 :(得分:0)
您必须通过会话获取值,而不是事后。在你的行
bundle->session->Run(run_options, input_tensors,
output_tensor_names, {}, &outputs,
&run_metadata)
您不仅应该要求预测的结果,还要求状态张量。请注意,在对session->Run
的特定调用之后,不存储状态的值,仅存储变量,RNN的状态是计算值,并在返回请求的结果后删除。
我不使用c ++接口,所以请原谅我在python中提供代码示例,希望它仍然有用(我敢打赌这不是你第一次不得不忍受阅读python示例),在python中,这看起来像:
prediction, state = sess.run([prediction_tensor, state_tensor], feed_dict=...)