我按照以下教程导出了由 Estimator 创建的 SavedModel：https://github.com/MtDersvan/tf_playground/blob/master/wide_and_deep_tutorial/wide_and_deep_basic_serving.md
我试图在 C++ 中加载这个模型，并且已经用 C++ 构造出了一个序列化后的 tensorflow::Example。如何把它转换成单个 Tensor？
本教程使用的是 tf.contrib.util.make_tensor_proto(serialized, shape=[1])。等效的 C++ API 是什么？
答案 0（得分：0）
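这比用 Python 写要复杂一些。C++ 中没有与 make_tensor_proto 直接对应的函数，但可以手动构造 TensorProto：将 dtype 设为 DT_STRING，形状设为 [1]，再把序列化后的 Example 作为唯一的 string_val 加入。下面的代码示例（通过 gRPC 调用 TensorFlow Serving）可能会有所帮助：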
#include <gtest/gtest.h>
#include <grpc/grpc.h>
#include <grpc++/channel.h>
#include <grpc++/client_context.h>
#include <grpc++/create_channel.h>
#include <grpc++/security/credentials.h>
#include <grpc++/security/credentials.h>
#include <iostream>
#include <memory>
#include <string>
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
#include "tensorflow_serving/apis/prediction_service.pb.h"
using namespace std;
namespace tensorflow {
namespace serving {

// Sends one serialized tensorflow::Example to a TensorFlow Serving instance
// through the gRPC PredictionService and prints the response.
//
// The input tensor built here is the C++ counterpart of the Python call
//   tf.contrib.util.make_tensor_proto(serialized, shape=[1])
// i.e. a DT_STRING TensorProto with shape [1] whose single string_val element
// holds the serialized Example bytes.
//
// NOTE(review): this is an integration test, not a unit test. It assumes a
// server listening on 127.0.0.1:30355 that serves a model named "default"
// with an input named "inputs" -- adjust these for your own deployment.
TEST(GRPCCppClient, TestPredict) {
  const std::shared_ptr<grpc::Channel> channel =
      grpc::CreateChannel("127.0.0.1:30355", grpc::InsecureChannelCredentials());
  // NewStub hands back a std::unique_ptr; this test is its only owner.
  const std::unique_ptr<PredictionService::Stub> stub =
      PredictionService::NewStub(channel);

  Example example;
  /*Construct example..*/
  std::string serialized_example;
  // SerializeToString reports failure via its return value; don't ignore it.
  ASSERT_TRUE(example.SerializeToString(&serialized_example))
      << "Failed to serialize tensorflow::Example";

  PredictRequest predict_request;
  predict_request.mutable_model_spec()->set_name("default"); /*set your own model name*/

  // Build the input tensor directly inside the request's inputs map instead of
  // assembling a separate TensorProto/TensorShapeProto and copying them in.
  tensorflow::TensorProto& input_tensor =
      (*predict_request.mutable_inputs())["inputs"];
  input_tensor.set_dtype(DT_STRING);
  input_tensor.mutable_tensor_shape()->add_dim()->set_size(1);  // shape=[1]
  input_tensor.add_string_val(serialized_example);

  grpc::ClientContext context;
  PredictResponse predict_response;
  const grpc::Status status =
      stub->Predict(&context, predict_request, &predict_response);
  // Surface the real gRPC error instead of a generic message + ASSERT_TRUE(false).
  ASSERT_TRUE(status.ok()) << "Predict rpc failed: " << status.error_message();

  std::cout << "------------\n" << predict_response.DebugString() << std::endl;
}

}  // namespace serving
}  // namespace tensorflow