如何在C ++中将火炬模型定义为函数的输入

时间:2019-07-26 10:31:51

标签: c++ libtorch

我正在用c ++加载一个模型,该模型是经过python训练的。现在,我想编写一个函数,使用随机输入来测试模型,但无法将模型定义为该函数的参数。我已经尝试过struct,但是没有用。

void test(vector<struct comp*>& model){
    //pseudo input
    vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({1,3,224, 224}));

    at::Tensor output = model[0]->forward(inputs).toTensor();
    cout << output << endl;
}

int main(int argc, char *argv[]) {

    if (argc == 2){
        cout << argv[1] << endl;
        //model = load_model(argv[1]);
        torch::jit::script::Module module = torch::jit::load(argv[1]);

    }
    else {
        cerr << "no path of model is given" << endl;
    }
    // test
    vector<struct comp*> modul;
    modul.push_back(module);
    test(modul);
}

1 个答案:

答案 0 :(得分:3)

编辑:您需要将module变量放入范围!

您的基本类型为torch::jit::script::Module,因此请为其定义名称:

using module_type = torch::jit::script::Module;

然后在您的代码中使用它,还对只读参数使用const引用:

void test(const vector<module_type>& model){
    //pseudo input
    vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({1,3,224, 224}));

    at::Tensor output = model[0]->forward(inputs).toTensor();
    cout << output << endl;
}

int main(int argc, char *argv[]) {

    if (argc == 2){
        cout << argv[1] << endl;            
    }
    else {
        cerr << "no path of model is given" << endl;
        return -1;
    }

    // test
    module_type module = torch::jit::load(argv[1]);;
    vector<module_type> modul;
    modul.push_back(module);
    test(modul);
}