通过C++运行冻结的(frozen).pb tensorflow.keras模型

时间:2019-02-04 20:22:33

标签: c++ tensorflow keras

我训练了一个tf.keras模型,将其冻结并保存为*.pb文件。现在,我想使用下面的代码在C++中运行该模型。

// example.cpp
#include <stdio.h>

#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// #include "tensorflow/cc/ops/image_ops.h"
// #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"

#include <opencv2/opencv.hpp>

using namespace std;
using namespace tensorflow;
using namespace cv;

int main()
{
    Mat image;
    image = imread("rgb_000000.png", 1);

    // Check if image has something
    if (!image.data)
    {
        std::cout << "Could find content in the image" << std::endl;
        return 0;
    }
    // get the start time to report
    auto start_total = std::chrono::high_resolution_clock::now();

    // Get dimensions
    unsigned int cv_img_h = image.rows;
    unsigned int cv_img_w = image.cols;
    unsigned int cv_img_d = image.channels();

    // Set up inputs to run the graph
    // tf tensor for feeding the graph
    tensorflow::Tensor x_pl(DT_FLOAT, {1, cv_img_h, cv_img_w, cv_img_d});

    // tf pointer for init of fake cv mat
    float *x_pl_pointer = x_pl.flat<float>().data();

    // fake cv mat (avoid copy)
    Mat x_pl_cv(cv_img_h, cv_img_w, CV_32FC3, x_pl_pointer);
    image.convertTo(x_pl_cv, CV_32FC3);

    // feed the input
    vector<std::pair<std::string, tensorflow::Tensor>> inputs = {{"input", x_pl}};

    // The session will initialize the outputs
    std::vector<tensorflow::Tensor> outputs;

    // Initialize a tensorflow session
    Session *session;
    tensorflow::SessionOptions options = SessionOptions();
    options.config.mutable_gpu_options()->set_allow_growth(true);
    Status status = NewSession(options, &session);
    if (!status.ok())
    {
        std::cout << status.ToString() << "\n";
        return 1;
    }
    cout << "Session successfully created.\n";

    // Read in the protobuf graph we exported
    // See https://stackoverflow.com/a/43639305/1076564 for other ways of saving and restoring Tensorflow graphs.
    cout<<"The model is loading\n";

    GraphDef graph_def;
    status = ReadBinaryProto(Env::Default(), std::string("frozen_model.pb"), &graph_def);
    if (!status.ok())
    {
        std::cout << status.ToString() << "\n";
        return 1;
    }

    cout<<"The model is loaded to the session\n";

    // Add the graph to the session
    status = session->Create(graph_def);
    if (!status.ok())
    {
        std::cout << status.ToString() << "\n";
        return 1;
    }

    cout<< "The session created successfully.\n";

    // Run the session, evaluating our "c" operation from the graph
    status = session->Run(inputs, {"plant_output", "stem_output"}, {}, &outputs);
    if (!status.ok())
    {
        std::cout << status.ToString() << "\n";
        return 1;
    }

    cout<<"\nAfter running the model\n";

    // (There are similar methods for vectors and matrices here:
    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h)

    // Print the results
    // Process the output with map
    // Get output dimensions
    std::cout << outputs[0].shape().dim_size(0) << "\n"; // Tensor<type: float shape: [] values: 6>
    std::cout << outputs[0].shape().dim_size(1) << "\n"; // Tensor<type: float shape: [] values: 6>
    std::cout << outputs[0].shape().dim_size(2) << "\n"; // Tensor<type: float shape: [] values: 6>
    std::cout << outputs[0].shape().dim_size(3) << "\n"; // Tensor<type: float shape: [] values: 6>
    std::cout << outputs[1].shape().dim_size(0) << "\n"; // Tensor<type: float shape: [] values: 6>
    std::cout << outputs[1].shape().dim_size(1) << "\n"; // Tensor<type: float shape: [] values: 6>
    std::cout << outputs[1].shape().dim_size(2) << "\n"; // Tensor<type: float shape: [] values: 6>
    std::cout << outputs[1].shape().dim_size(3) << "\n"; // Tensor<type: float shape: [] values: 6>

    // Free any resources used by the session
    session->Close();
    return 0;
}

补充说明:该模型有一个输入和两个输出。我使用tensorflow_cc来构建代码。代码构建时没有任何问题,但运行时,在创建会话之后出现了下面的错误。

Session successfully created.
The model is loading
The model is loaded to the session
2019-02-04 16:15:22.405227: E tensorflow/core/framework/op_kernel.cc:1197] OpKernel ('op: "MutableDenseHashTableV2" device_type: "CPU" constraint { name: "key_dtype" allowed_values { list { type: DT_STRING } } } constraint { name: "value_dtype" allowed_values { list { type: DT_INT64 } } }') for unknown op: MutableDenseHashTableV2
.
.
.
2019-02-04 16:15:22.428918: E tensorflow/core/framework/op_kernel.cc:1197] OpKernel ('op: "PopulationCount" device_type: "CPU" constraint { name: "T" allowed_values { list { type: DT_INT8 } } }') for unknown op: PopulationCount
2019-02-04 16:15:22.428941: E tensorflow/core/framework/op_kernel.cc:1197] OpKernel ('op: "PopulationCount" device_type: "CPU" constraint { name: "T" allowed_values { list { type: DT_UINT8 } } }') for unknown op: PopulationCount
Not found: Op type not registered 'ReadVariableOp' in binary running on ****-desktop. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.

我不明白这个错误的原因。

0 个答案:

没有答案