如何在C ++中提取火炬模型的输出?

时间:2020-06-02 18:21:46

标签: c++ machine-learning neural-network pytorch torch

我有训练有素的keras模型,并使用mmdnn对其进行了转换。然后我尝试在c ++代码中使用它:

#include <iostream>

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>

#include <torch.h>

int main()
{
    cv::Mat image;
    image= cv::imread("test_img.png", cv::IMREAD_GRAYSCALE);   // Read the file

try
{
    torch::jit::script::Module module;
    module = torch::jit::load("my_model.pth");

    torch::IntArrayRef input_dim = std::vector<int64_t>({ 1, 2, 256, 256});

    cv::Mat input_img;
    image.convertTo(input_img, CV_32FC3, 1 / 255.0);
    torch::Tensor x = torch::from_blob(input_img.data, { 1, 2, 256, 256 }, torch::kFloat);
    torch::NoGradGuard no_grad;

    auto output = module.forward({ x });

    float* data = static_cast<float*>(output.toTensor().data_ptr());

    cv::Mat output_img = cv::Mat(256, 256, CV_32FC3, data);
    cv::imwrite("output_img.png", output_img);
}
catch (std::exception &ex)
{
    std::cout << "exception! " << ex.what() << std::endl;
}

    return 0;
}

此代码引发异常:

例外! isTensor()内部评估失败 E:\ 20B \ pytorch \ pytorch \ aten \ src \ ATen / core / ivalue_inl.h:112,请 向PyTorch报告错误。预期Tensor但获得了Tuple(toTensor at E:\ 20B \ pytorch \ pytorch \ aten \ src \ ATen / core / ivalue_inl.h:112)(否 回溯可用)

在调用函数float* data = static_cast<float*>(output.toTensor().data_ptr());时,将其抛出到行toTensor()中。如果我使用toTuple()而不是toTensor(),则结果不具有函数data_ptr(),但是提取数据(并将其放入opencv图像中)需要使用此函数。

如何从模型输出中提取图像?

1 个答案:

答案 0 :(得分:1)

在这种情况下,模型的答案是2张图像的元组。我们可以通过以下方式提取它们:

torch::Tensor t0 = output.toTuple()->elements()[0].toTensor();
torch::Tensor t1 = output.toTuple()->elements()[1].toTensor();

变量t0t1包含带有模型输出的张量。

相关问题