如何将序列化的 tensorflow::Example 转换为 Session 的输入

时间:2017-06-30 14:22:29

标签: tensorflow

我按照以下教程导出了由 Estimator 创建的 SavedModel： https://github.com/MtDersvan/tf_playground/blob/master/wide_and_deep_tutorial/wide_and_deep_basic_serving.md

我想在 C++ 中加载这个模型。我已经用 C++ 创建了一个序列化后的 tensorflow::Example，该如何把它转换成单个 Tensor？教程中使用的是 tf.contrib.util.make_tensor_proto(serialized, shape=[1])，对应的 C++ API 是什么？

1 个答案:

答案 0（得分：0）

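这比用 Python 写要复杂一些。Python 中的 make_tensor_proto 生成的是一个 TensorProto；在 C++ 中可以手动构造等价的 TensorProto：将 dtype 设为 DT_STRING、形状设为 [1]，再把序列化后的 Example 字符串追加到 string_val 中。下面的示例通过 TensorFlow Serving 的 gRPC Predict 接口发送这个 TensorProto，可能会有所帮助。（如果要在进程内直接调用 Session::Run，则需要的是 tensorflow::Tensor，例如 Tensor t(DT_STRING, TensorShape({1}))，再通过 t.flat<string>()(0) 写入序列化字符串。）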

#include <chrono>
#include <iostream>
#include <memory>
#include <string>

#include <grpc/grpc.h>
#include <grpc++/channel.h>
#include <grpc++/client_context.h>
#include <grpc++/create_channel.h>
#include <grpc++/security/credentials.h>
#include <grpc++/security/credentials.h>
#include <gtest/gtest.h>

#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
#include "tensorflow_serving/apis/prediction_service.pb.h"

using namespace std;

namespace tensorflow {
namespace serving {

// Smoke test for a model served by TensorFlow Serving.
//
// Serializes a tensorflow::Example and wraps it in a 1-element DT_STRING
// TensorProto of shape [1]. This is the C++ counterpart of Python's
// tf.contrib.util.make_tensor_proto(serialized, shape=[1]). The test then sends
// it through the PredictionService Predict RPC and prints the response.
//
// Requires a model server reachable at kServerAddress. kModelName and
// kInputKey must match the served model and its signature's input alias.
TEST(GRPCCppClient, TestPredict) {
    // Adjust these to your deployment.
    const std::string kServerAddress = "127.0.0.1:30355";
    const std::string kModelName = "default";  // set your own model name
    const std::string kInputKey = "inputs";
    // Without a deadline an unresponsive server makes the test hang forever;
    // raise this if your model is slow to answer.
    const std::chrono::seconds kRpcTimeout(10);

    // NewStub returns a unique_ptr; the stub keeps its own reference to the
    // channel, so the temporary channel handle can be dropped here.
    std::unique_ptr<PredictionService::Stub> stub = PredictionService::NewStub(
        grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials()));

    Example example;
    /* Populate `example` with the features your model expects here. */
    std::string serialized_example;
    ASSERT_TRUE(example.SerializeToString(&serialized_example))
        << "failed to serialize tensorflow::Example";

    PredictRequest request;
    request.mutable_model_spec()->set_name(kModelName);

    // Build the input tensor directly inside the request's input map instead
    // of assembling it separately and copying it in.
    TensorProto& input = (*request.mutable_inputs())[kInputKey];
    input.set_dtype(DT_STRING);
    input.mutable_tensor_shape()->add_dim()->set_size(1);
    input.add_string_val(serialized_example);

    grpc::ClientContext context;
    context.set_deadline(std::chrono::system_clock::now() + kRpcTimeout);

    PredictResponse response;
    const grpc::Status status = stub->Predict(&context, request, &response);
    // Report the real gRPC failure instead of a generic message.
    ASSERT_TRUE(status.ok()) << "Predict rpc failed: code "
                             << status.error_code() << ": "
                             << status.error_message();

    std::cout << "------------\n" << response.DebugString() << std::endl;
}

}  // namespace serving
}  // namespace tensorflow