我正在尝试使用pytorch / libtorch实现一个简单的神经网络。以下示例改编自libtorch cpp frontend tutorial。
#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
DeepQImpl(size_t N)
: linear1(2,5),
linear2(5,3) {}
torch::Tensor forward(torch::Tensor x) const {
x = torch::tanh(linear1(x));
x = linear2(x);
return x;
}
torch::nn::Linear linear1, linear2;
};
TORCH_MODULE(DeepQ);
请注意,函数forward
被声明为const
。我正在编写的代码要求将NN评估为const函数,这对我来说似乎很合理。
这段代码不会编译。编译器抛出
错误:与“(const torch :: nn :: Linear)(at :: Tensor&)”的调用不匹配
x = linear1(x);
不过,通过将图层定义为mutable
,我找到了解决此问题的方法:
#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
/* all the code */
mutable torch::nn:Linear linear1, linear2;
};
所以我的问题是
const
mutable
来解决此问题,并且安全吗?我的直觉是,在前向传递中,各层组装成可用于向后传播的结构,需要进行一些写操作。如果是这样,那么问题就变成了如何在第一步(非const
)中组装图层,然后在第二步(const
)中评估结构。