I'm looking at the sample provided in the TensorFlow git repository for Android devices. It uses the Java interface as a wrapper around the C++ API. Are there any examples where I can use the C++ API directly to initialize TensorFlow, load a model, run inference, and so on?
Answer 0 (score: 0)
Check out this repo and the following blog for a solution. These links provide step-by-step instructions on how to use the TensorFlow C++ API on Android. The idea is to build an Android-friendly dynamic library (.so file), i.e. one that excludes the TensorFlow components that are only compatible with desktop/GPU environments.
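The linked pages cover the build details. As a rough sketch of where such a library ends up being called from an Android app, a JNI entry point into the resulting .so could look like the following (the package/class name com.example.tf.TFNative and the model-path handling are my assumptions, not taken from the linked repo):

#include <jni.h>
#include <memory>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"

// Session kept alive across JNI calls; a real app would manage this more carefully.
static std::unique_ptr<tensorflow::Session> g_session;

// Hypothetical JNI entry point; exposed to Java as TFNative.loadModel(String).
extern "C" JNIEXPORT jboolean JNICALL
Java_com_example_tf_TFNative_loadModel(JNIEnv* env, jobject /*thiz*/,
                                       jstring model_path) {
  const char* path = env->GetStringUTFChars(model_path, nullptr);
  tensorflow::GraphDef graph_def;
  // Read the frozen graph from a path the app can access
  // (e.g. a file copied from assets to internal storage).
  tensorflow::Status s =
      ReadBinaryProto(tensorflow::Env::Default(), path, &graph_def);
  env->ReleaseStringUTFChars(model_path, path);
  if (!s.ok()) return JNI_FALSE;
  g_session.reset(tensorflow::NewSession(tensorflow::SessionOptions()));
  return g_session->Create(graph_def).ok() ? JNI_TRUE : JNI_FALSE;
}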
Answer 1 (score: 0)
I wrote this code for a Raspberry Pi, but I believe it should be almost the same for Android:
tfbenchmark.h:
#ifndef TENSORFLOW_H
#define TENSORFLOW_H

#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

class TensorFlowBenchmark {
 public:
  TensorFlowBenchmark();
  virtual ~TensorFlowBenchmark();

  bool init();
  bool run();

 private:
  std::unique_ptr<tensorflow::Session> session_;
};

#endif /* TENSORFLOW_H */
tfbenchmark.cpp:
#include "tfbenchmark.h"
#include <vector>
#include <fstream>
#include <chrono>
#include <ctime>
#include <cstddef>
#include <jpeglib.h>
#include <setjmp.h>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
// These are all common classes it's handy to reference with no namespace.
using tensorflow::Flag;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32;
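// Benchmark configuration: model/image paths, input geometry, normalization
// constants, and how many evaluations to run per batch size.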
const static string root_dir = ".";
const static string image = "../input.jpg";
const static string graph = "models/frozen_graph.pb";
const static int32 input_width = 224;
const static int32 input_height = 224;
const static int32 input_mean = 128;
const static int32 input_std = 128;
const static string input_layer = "x_input_pl";
const static string output_layer = "out/out";
const static int NUM_EVAL = 100;
const static int MAX_BATCH = 256;
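// Reports the mean wall-clock time per image over all timed runs for the
// given batch size.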
template<class T>
void report_metrics(const std::vector<T>& v, int batch_size) {
double sum = std::accumulate(v.begin(), v.end(), 0.0);
double mean = sum / v.size();
LOG(INFO) << "Batch size = " << batch_size << ": "
<< mean/batch_size << "ms per image";
}
// Error handling for JPEG decoding.
void CatchError(j_common_ptr cinfo) {
(*cinfo->err->output_message)(cinfo);
jmp_buf* jpeg_jmpbuf = reinterpret_cast<jmp_buf*>(cinfo->client_data);
jpeg_destroy(cinfo);
longjmp(*jpeg_jmpbuf, 1);
}
// Decompresses a JPEG file from disk.
Status LoadJpegFile(string file_name, std::vector<tensorflow::uint8>* data,
int* width, int* height, int* channels) {
struct jpeg_decompress_struct cinfo;
FILE* infile;
JSAMPARRAY buffer;
int row_stride;
if ((infile = fopen(file_name.c_str(), "rb")) == NULL) {
LOG(ERROR) << "Can't open " << file_name;
return tensorflow::errors::NotFound("JPEG file ", file_name,
" not found");
}
struct jpeg_error_mgr jerr;
jmp_buf jpeg_jmpbuf; // recovery point in case of error
cinfo.err = jpeg_std_error(&jerr);
cinfo.client_data = &jpeg_jmpbuf;
jerr.error_exit = CatchError;
if (setjmp(jpeg_jmpbuf)) {
return tensorflow::errors::Unknown("JPEG decoding failed");
}
jpeg_create_decompress(&cinfo);
jpeg_stdio_src(&cinfo, infile);
jpeg_read_header(&cinfo, TRUE);
jpeg_start_decompress(&cinfo);
*width = cinfo.output_width;
*height = cinfo.output_height;
*channels = cinfo.output_components;
data->resize((*height) * (*width) * (*channels));
row_stride = cinfo.output_width * cinfo.output_components;
buffer = (*cinfo.mem->alloc_sarray)((j_common_ptr) & cinfo, JPOOL_IMAGE,
row_stride, 1);
while (cinfo.output_scanline < cinfo.output_height) {
tensorflow::uint8* row_address =
&((*data)[cinfo.output_scanline * row_stride]);
jpeg_read_scanlines(&cinfo, buffer, 1);
memcpy(row_address, buffer[0], row_stride);
}
jpeg_finish_decompress(&cinfo);
jpeg_destroy_decompress(&cinfo);
fclose(infile);
return Status::OK();
}
// Given decoded image data, copies it into a float tensor of the requested
// batch size, scaling the values as the model expects.
Status FillTensorFromImageData(std::vector<tensorflow::uint8>& image_data,
const int batch_size, const int image_height,
const int image_width, const int image_channels,
std::vector<Tensor>* out_tensors) {
// In these loops we convert the eight-bit image data to float and scale it
// numerically to the range the model expects (given by input_mean and
// input_std). The same image is replicated across the batch dimension; no
// resizing is done, so the input is assumed to already be
// input_width x input_height.
tensorflow::Tensor image_tensor(
tensorflow::DT_FLOAT,
tensorflow::TensorShape(
{batch_size, image_height, image_width, image_channels}));
auto image_tensor_mapped = image_tensor.tensor<float, 4>();
LOG(INFO) << image_data.size() << "bytes in image_data";
tensorflow::uint8* in = image_data.data();
float* out = image_tensor_mapped.data();
for(int n = 0; n < batch_size; n++) {
for (int y = 0; y < image_height; ++y) {
tensorflow::uint8* in_row = in + (y * image_width * image_channels);
float* out_row = out + (n * image_height * image_width * image_channels) + (y * image_width * image_channels);
for (int x = 0; x < image_width; ++x) {
tensorflow::uint8* input_pixel = in_row + (x * image_channels);
float* out_pixel = out_row + (x * image_channels);
for (int c = 0; c < image_channels; ++c) {
out_pixel[c] =
static_cast<float>(input_pixel[c] - input_mean) / input_std;
}
}
}
}
out_tensors->push_back(image_tensor);
return Status::OK();
}
// Reads a model graph definition from disk, and creates a session object you
// can use to run it.
Status LoadGraph(string graph_file_name,
std::unique_ptr<tensorflow::Session>* session) {
tensorflow::GraphDef graph_def;
Status load_graph_status = ReadBinaryProto(tensorflow::Env::Default(),
graph_file_name, &graph_def);
if (!load_graph_status.ok()) {
return tensorflow::errors::NotFound("Failed to load compute graph at '",
graph_file_name, "'");
}
session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
Status session_create_status = (*session)->Create(graph_def);
if (!session_create_status.ok()) {
return session_create_status;
}
return Status::OK();
}
TensorFlowBenchmark::TensorFlowBenchmark() {}
TensorFlowBenchmark::~TensorFlowBenchmark() {}
bool TensorFlowBenchmark::init() {
// We need to call this to set up global state for TensorFlow.
int argc = 0;
char** argv = nullptr;
tensorflow::port::InitMain("benchmark", &argc, &argv);
string graph_path = tensorflow::io::JoinPath(root_dir, graph);
Status load_graph_status = LoadGraph(graph_path, &session_);
if (!load_graph_status.ok()) {
LOG(ERROR) << load_graph_status;
return false;
}
return true;
}
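// Loads the test JPEG once, then benchmarks inference for power-of-two batch
// sizes up to MAX_BATCH.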
bool TensorFlowBenchmark::run() {
string image_path = tensorflow::io::JoinPath(root_dir, image);
std::vector<tensorflow::uint8> image_data;
int image_width;
int image_height;
int image_channels;
Status load_img_status = LoadJpegFile(image_path, &image_data, &image_width, &image_height,
&image_channels);
if(!load_img_status.ok()) {
LOG(ERROR) << load_img_status;
return false;
}
LOG(INFO) << "Loaded JPEG: " << image_width << "x" << image_height << "x"
<< image_channels;
for(int batch_size = 1; batch_size <= MAX_BATCH; batch_size <<= 1) {
LOG(INFO) << "Batch size " << batch_size;
std::vector<Tensor> resized_tensors;
Status read_tensor_status =
FillTensorFromImageData(image_data, batch_size, image_height, image_width,
image_channels, &resized_tensors);
if (!read_tensor_status.ok()) {
LOG(ERROR) << read_tensor_status;
return false;
}
const Tensor& resized_tensor = resized_tensors[0];
// Actually run the image through the model.
std::vector<Tensor> outputs;
std::vector<long> timings;
for (int i = 0; i < NUM_EVAL; ++i) {
auto start = std::chrono::system_clock::now();
Status run_status = session_->Run({{input_layer, resized_tensor}},
{output_layer}, {}, &outputs);
auto end = std::chrono::system_clock::now();
timings.push_back( std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count());
if (!run_status.ok()) {
LOG(ERROR) << "Running model failed: " << run_status;
return false;
}
}
report_metrics(timings, batch_size);
}
return true;
}
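A minimal driver for this class might look like the sketch below (the file name main.cpp is my assumption; since the graph and image paths are hard-coded above, it has to be run from a directory containing models/frozen_graph.pb, with the test image at ../input.jpg):

// main.cpp -- hypothetical driver for TensorFlowBenchmark.
#include "tfbenchmark.h"

int main() {
  TensorFlowBenchmark benchmark;
  if (!benchmark.init()) return 1;   // loads models/frozen_graph.pb
  return benchmark.run() ? 0 : 1;    // decodes ../input.jpg and times inference
}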