如何为libtorch中的网络自定义后向功能? (我测试了一些代码,但是失败了。)

时间:2019-06-26 03:53:26

标签: c++ pytorch libtorch

我正在使用libtorch(PyTorch的C++前端)来训练网络。为了监视梯度,我创建了一个检查器(checkBP)来拦截反向传播过程,但是程序并没有进入自定义的backward()函数。

我从Ubuntu16.04上的源代码编译libtorch。

checkBP类在下面。

#include <torch/torch.h>

#include <cstddef>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>

// Debug module intended to log gradient statistics during backpropagation.
//
// NOTE(review): torch::nn::Module has no virtual backward() hook. loss.backward()
// dispatches through the autograd graph built from tensor operations, never
// through module methods — so this backward() is never invoked. To actually
// intercept gradients, implement a torch::autograd::Function subclass with
// static forward()/backward() members, or register a tensor hook. The method
// is kept here with its original signature so existing callers still compile.
class checkBP: public torch::nn::Module
{
public:
    // show:  1 => print the mean absolute gradient when backward() runs.
    // label: tag identifying this probe in the printout.
    checkBP(int show, std::string label) : show(show), label(std::move(label))
    {
        std::cout << " created checkBP" << std::endl;
    }

    // Identity forward; clone() so the probe does not alias its input tensor.
    torch::Tensor forward(torch::Tensor inputs)
    {
        std::cout << "\n\n\n\n checkBP forward passed\n\n\n\n" << std::endl;
        return inputs.clone();
    }

    // Would report the mean |grad| — but see the class NOTE: autograd never
    // calls this method.
    torch::Tensor backward(torch::Tensor grad_output)
    {
        std::cout << " In checkBP backward." << std::endl;
        // item<float>() safely extracts a scalar from CPU or GPU tensors;
        // the original data<float>() raw-pointer deref is deprecated and
        // only valid for CPU tensors.
        const float grad_mean_float = grad_output.abs().mean().item<float>();

        if (show == 1)
        {
            std::cout << "!!!! checkBP grad of " << label
                      << " is(show=1): " << grad_mean_float << std::endl;
        }

        return grad_output;
    }

public:
    int show;
    std::string label;
};

class Net_model: public nn::Module
{
public:
    Net_model(int n_input,
              int n_hidden,
              int n_out){
        fc1 = register_module("fc1_net",torch::nn::Linear(n_input, n_hidden));
        fc2 = register_module("fc2_net",torch::nn::Linear(n_hidden, n_out));
    }

    nn::Linear fc1{nullptr}, fc2{nullptr};
    torch::Tensor forward(torch::Tensor input){
        auto x = fc1->forward(input);
        x = checkBP(/*show*/1,"net_model_checkBP").forward(x);
        x = fc2->forward(x);
        return x;
    }
};

int main()
{
    // Problem sizes: 10 features -> 8 hidden units -> 2 outputs, 3 samples.
    const int n_input = 10, n_hidden = 8, n_out = 2, n_sample = 3;

    // All-ones inputs regressed toward an all-zeros target.
    torch::Tensor inputs = torch::ones({n_sample, n_input});
    torch::Tensor target = torch::zeros({n_sample, n_out});

    auto net = Net_model(n_input, n_hidden, n_out);
    std::cout << net << std::endl;

    // Plain SGD with momentum, lr = 0.01.
    torch::optim::SGD optimizer(
        net.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));

    // Ten optimization steps on the MSE loss.
    for (int step = 0; step < 10; ++step)
    {
        auto prediction = net.forward(inputs);
        optimizer.zero_grad();
        auto loss = torch::mse_loss(prediction, target);
        loss.backward();
        optimizer.step();
    }

    return 0;
}

在训练循环中:执行 loss.backward() 之后,checkBP::backward() 似乎没有被执行。那么在libtorch中定义backward的正确方法是什么?

谢谢您的建议!

0 个答案:

没有答案