如何从TF中的Graph保存RNN的状态?

时间:2018-03-28 20:22:10

标签: c++ tensorflow

以下代码来自tensorflow服务API:

// Implementation of Predict using the SavedModel SignatureDef format.
Status SavedModelPredict(const RunOptions& run_options, ServerCore* core,
                         const PredictRequest& request,
                         PredictResponse* response) {
  // Validate signatures.
  ServableHandle<SavedModelBundle> bundle;
  TF_RETURN_IF_ERROR(core->GetServableHandle(request.model_spec(), &bundle));

  const string signature_name = request.model_spec().signature_name().empty()
                                    ? kDefaultServingSignatureDefKey
                                    : request.model_spec().signature_name();
  auto iter = bundle->meta_graph_def.signature_def().find(signature_name);
  if (iter == bundle->meta_graph_def.signature_def().end()) {
    return errors::FailedPrecondition(strings::StrCat(
        "Serving signature key \"", signature_name, "\" not found."));
  }
  SignatureDef signature = iter->second;

  MakeModelSpec(request.model_spec().name(), signature_name,
                bundle.id().version, response->mutable_model_spec());

  std::vector<std::pair<string, Tensor>> input_tensors;
  std::vector<string> output_tensor_names;
  std::vector<string> output_tensor_aliases;
  TF_RETURN_IF_ERROR(PreProcessPrediction(signature, request, &input_tensors,
                                          &output_tensor_names,
                                          &output_tensor_aliases));
  std::vector<Tensor> outputs;
  RunMetadata run_metadata;
  TF_RETURN_IF_ERROR(bundle->session->Run(run_options, input_tensors,
                                          output_tensor_names, {}, &outputs,
                                          &run_metadata));

  return PostProcessPredictionResult(signature, output_tensor_aliases, outputs,
                                     response);
}

此代码使用存储的模型运行预测。在我的例子中,这个存储的模型是RNN。

在实际执行预测的以下行之后:

TF_RETURN_IF_ERROR(bundle->session->Run(run_options, input_tensors,
                                      output_tensor_names, {}, &outputs,
                                      &run_metadata));

我想将RNN的状态保存到文件/内存中,以便我可以在每次预测后的日期访问它们。我假设可以通过变量访问这些状态:

bundle->meta_graph_def

但目前尚不清楚如何具体访问RNN的状态值,然后将其保存到文件中。

1 个答案:

答案 0 :(得分:0)

您必须通过会话获取值,而不是事后。在你的行

bundle->session->Run(run_options, input_tensors,
                                  output_tensor_names, {}, &outputs,
                                  &run_metadata)

您不仅应该要求预测的结果,还要求状态张量。请注意,在对session->Run的特定调用之后,不存储状态的值,仅存储变量,RNN的状态是计算值,并在返回请求的结果后删除。

我不使用c ++接口,所以请原谅我在python中提供代码示例,希望它仍然有用(我敢打赌这不是你第一次不得不忍受阅读python示例),在python中,这看起来像:

prediction, state = sess.run([prediction_tensor, state_tensor], feed_dict=...)