使用libtorch将图层代码从python(pytorch)转换为C ++

时间:2019-12-12 16:43:08

标签: c++ pytorch libtorch

我正在尝试使用C ++中的libtorch实现EDSR模型。我有一个Python类,需要将其转换为C ++结构。我的Python代码是:

class MeanShift(nn.Conv2d):
    def __init__(self, rgb_mean, sign):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.bias.data = float(sign) * torch.Tensor(rgb_mean)

        for params in self.parameters():
            params.requires_grad = False

我的C ++结构是:

struct MeanShift: torch::nn::Conv2d{
    MeanShift(torch::nn::Conv2dOptions opt): torch::nn::Conv2d(opt){

    }
};

如何获取平均移位层的权重和偏差?

0 个答案:

没有答案