PyTorch c ++扩展中的可选张量

时间:2019-02-14 07:47:23

标签: c++ pytorch torch

我正在为pytorch编写C ++扩展,并使用c ++ api来实现。对于我的forward函数,我需要传递一个可选的张量。在函数内部,我想根据是否传递了此可选参数来做不同的事情。通常,对于C ++中的可选指针参数,我们使用NULL,然后在函数内部检查指针是否为NULL。我不知道如何针对at::Tensor类型的Torch C ++ API执行此操作。

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    const at::Tensor optional_constraints = something)
{
     if(optional_constraints){
        //do something
     }else{
        //do something else
     }
}

请注意,我无法执行const at::Tensor optional_constraints = at::ones之类的操作,因为该参数可以采用任何实际值,并且可以具有不同的大小/形状。我不能为它分配一个数值作为可选参数。是否有NULL个等效项?

2 个答案:

答案 0 :(得分:0)

因为我找不到类似的东西。 API中的OpenCV <link href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css" rel="stylesheet"/> <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script> <script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/js/bootstrap.min.js"></script> <link href="https://cdnjs.cloudflare.com/ajax/libs/bootstrap-datepicker/1.6.4/css/bootstrap-datepicker.css" rel="stylesheet"/> <script src="https://cdnjs.cloudflare.com/ajax/libs/bootstrap-datepicker/1.6.4/js/bootstrap-datepicker.js"></script> <div class="container"> <div class="row"> <div class="col-sm-6"> <div class="input-group date"> <input type="text" id="card_form_plan_start_date" class="form-control timepicker" name="card_form_plan_start_date"> <input type="text" id="card_form_plan_end_date" class="form-control timepicker" name="card_form_plan_end_date"> <div class="input-group-addon"> <span class="glyphicon glyphicon-th"></span> </div> </div> </div> </div> </div>(基本上用于传递最佳矩阵,如蒙版),我建议您为此目的使用重载函数

noArray()

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2)
{
     // optional tensor wasnt passed
}

答案 1 :(得分:0)

一种可能是将std::optional用作std::optional<at::Tensor> optional_constraints = std::nullopt。它可以从上下文转换为bool,因此可以使用if (optional_constraints)进行检查。如果传递张量,请使用.value()方法来获取张量,否则默认值为std::nullopt