如何在C ++中为Tensorflow-Lite设置图像输入?

时间:2019-05-20 14:16:41

标签: c++ embedded-linux tensor tensorflow-lite

我正在尝试在嵌入式平台上使用C ++将Tensoflow模型从Python + Keras版本迁移到Tensorflow Lite。

似乎我不知道如何正确设置解释器输入。

输入形状应为(1、224、224、3)。

作为输入,我正在使用openCV拍摄图像,并将其转换为CV_BGR2RGB。


std::unique_ptr<tflite::FlatBufferModel> model_stage1 = 
tflite::FlatBufferModel::BuildFromFile("model1.tflite");
  TFLITE_MINIMAL_CHECK(model_stage1 != nullptr);

  // Build the interpreter
  tflite::ops::builtin::BuiltinOpResolver resolver_stage1;
  std::unique_ptr<Interpreter> interpreter_stage1;
  tflite::InterpreterBuilder(*model_stage1, resolver_stage1)(&interpreter_stage1);

TFLITE_MINIMAL_CHECK(interpreter_stage1 != nullptr);

  cv::Mat cvimg = cv::imread(imagefile);
  if(cvimg.data == NULL) {
    printf("=== IMAGE READ ERROR ===\n");
    return 0;
  }

  cv::cvtColor(cvimg, cvimg, CV_BGR2RGB);

  uchar* input_1 = interpreter_stage1->typed_input_tensor<uchar>(0);

 memcpy( ... );

对于这种uchar类型,我正确设置了memcpy有问题。

当我这样做时,我在工作中遇到段错误:

memcpy(input_1, cvimg.data, cvimg.total() * cvimg.elemSize());

在这种情况下,如何正确填写输入内容?

2 个答案:

答案 0 :(得分:1)

要将我的评论转换为答案: Memcpy在这里可能不是正确的方法。 OpenCV将图像保存为每像素RGB顺序(或BGR或另一种颜色组合)颜色值的一维数组。可以通过以下方式遍历这些RGB块:

for (const auto& rgb : cvimg) {
    // now rgb[0] is the red value, rgb[1] green and rgb[2] blue.
}

然后将值写入Tensorflow-Lite typed_input_tensor中;其中,i是索引(迭代器),x是分配的值:

interpreter->typed_input_tensor<uchar>(0)[i] = x;

所以循环看起来像这样:

for (size_t i = 0; size_t < cvimg.size(); ++i) {
    const auto& rgb = cvimg[i];
    interpreter->typed_input_tensor<uchar>(0)[3*i + 0] = rgb[0];
    interpreter->typed_input_tensor<uchar>(0)[3*i + 1] = rgb[1];
    interpreter->typed_input_tensor<uchar>(0)[3*i + 2] = rgb[2];
}

答案 1 :(得分:0)

这是至少在单通道情况下可以做到的方式。这假定opencv bufffer是连续的。因此,在这种情况下,张量暗度为(1,x,y,1)。

float* out = interpreter->typed_tensor<float>(input);
input_type = interpreter->tensor(input)->type;
img.convertTo(img, CV_32F, 255.f/input_std);
cv::subtract(img, cv::Scalar(input_mean/input_std), img);
float* in = img.ptr<float>(0);
memcpy(out, in, img.rows * img.cols * sizeof(float));

OpenCV版本-4.3.0 TF Lite版本-2.0.0

纳达的方法也是正确的。选择适合您的编程风格的任何一种,但是memcpy版本将相对更快。