恢复NN前向通过的const正确性

时间:2019-04-18 18:48:32

标签: libtorch

我正在尝试使用pytorch / libtorch实现一个简单的神经网络。以下示例改编自libtorch cpp frontend tutorial

#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
    DeepQImpl(size_t N)
        : linear1(2,5),
          linear2(5,3) {}
    torch::Tensor forward(torch::Tensor x) const {
        x = torch::tanh(linear1(x));
        x = linear2(x);
        return x;
    }
    torch::nn::Linear linear1, linear2;
};
TORCH_MODULE(DeepQ);

请注意,函数forward被声明为const。我正在编写的代码要求将NN评估为const函数,这对我来说似乎很合理。 这段代码不会编译。编译器抛出

  

错误:与“(const torch :: nn :: Linear)(at :: Tensor&)”的调用不匹配
  x = linear1(x);

不过,通过将图层定义为mutable,我找到了解决此问题的方法:

#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
    /* all the code */
    mutable torch::nn:Linear linear1, linear2;
};

所以我的问题是

  1. 为什么在张量上应用层而不是const
  2. 是否正在使用mutable来解决此问题,并且安全吗?

我的直觉是,在前向传递中,各层组装成可用于向后传播的结构,需要进行一些写操作。如果是这样,那么问题就变成了如何在第一步(非const)中组装图层,然后在第二步(const)中评估结构。

0 个答案:

没有答案