具有图的C ++逆向自动微分

时间:2018-08-14 12:43:29

标签: c++ automatic-differentiation

我正在尝试用C ++制作一个reverse mode automatic differentiation

我想到的想法是,对一个或两个其他变量进行运算的每个变量将把梯度保存在向量中。

这是代码:

class Var {
    private:
        double value;
        char character;
        std::vector<std::pair<double, const Var*> > children;

    public:
        Var(const double& _value=0, const char& _character='_') : value(_value), character(_character) {};
        void set_character(const char& character){ this->character = character; }

        // computes the derivative of the current object with respect to 'var'
        double gradient(Var* var) const{
            if(this==var){
                return 1.0;
            }

            double sum=0.0;
            for(auto& pair : children){
                // std::cout << "(" << this->character << " -> " <<  pair.second->character << ", " << this << " -> " << pair.second << ", weight=" << pair.first << ")" << std::endl;
                sum += pair.first*pair.second->gradient(var);
            }
            return sum;
        }

        friend Var operator+(const Var& l, const Var& r){
            Var result(l.value+r.value);
            result.children.push_back(std::make_pair(1.0, &l));
            result.children.push_back(std::make_pair(1.0, &r));
            return result;
        }

        friend Var operator*(const Var& l, const Var& r){
            Var result(l.value*r.value);
            result.children.push_back(std::make_pair(r.value, &l));
            result.children.push_back(std::make_pair(l.value, &r));
            return result;
        }

        friend std::ostream& operator<<(std::ostream& os, const Var& var){
            os << var.value;
            return os;
        }
};

我试图运行这样的代码:

int main(int argc, char const *argv[]) {
    Var x(5,'x'), y(6,'y'), z(7,'z');

    Var k = z + x*y;
    k.set_character('k');

    std::cout << "k = " << k << std::endl;
    std::cout << "∂k/∂x = " << k.gradient(&x) << std::endl;
    std::cout << "∂k/∂y = " << k.gradient(&y) << std::endl;
    std::cout << "∂k/∂z = " << k.gradient(&z) << std::endl;

    return 0;
}

应建立的计算图如下:

       x(5)   y(6)              z(7)
         \     /                 /
 ∂w/∂x=y  \   /  ∂w/∂y=x        /
           \ /                 /
          w=x*y               /
             \               /  ∂k/∂z=1
              \             /
      ∂k/∂w=1  \           /
                \_________/
                     |
                   k=w+z

然后,例如,如果我要计算∂k/∂x,则必须乘以边缘之后的梯度,并对每个边缘求和。这是由double gradient(Var* var) const递归完成的。所以我有∂k/∂x = ∂k/∂w * ∂w/∂x + ∂k/∂z * ∂z/∂x

问题

如果我这里有x*y之类的中间计算,则出问题了。当std::cout取消注释时,输出为:

k = 37
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂x = 0
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂y = 5
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂z = 1

它显示哪个变量连接到哪个变量,然后打印它们的地址以及连接的权重(应该是渐变)。

问题是 weight=0x之间的中间变量,中间变量保存着x*y的结果(我在其中将其表示为w我的图表)。 我不知道为什么这个为零,而不是另一个权重连接到y

我注意到的另一件事是,如果您像这样切换operator*中的行,则:

result.children.push_back(std::make_pair(1.0, &r));
result.children.push_back(std::make_pair(1.0, &l));

然后y个连接将取消。

在此先感谢您的帮助。

1 个答案:

答案 0 :(得分:4)

该行:

Var k = z + x*y;

调用operator*,返回一个Var临时变量,然后将其用于r的{​​{1}}参数,其中operator+存储地址临时的。该行结束后,pair个子项包括指向临时曾经的位置的指针,但该指针不再存在。


虽然它不能防止上述错误,但是您可以通过避免未命名的临时事件来创建预期的行为...

k

...程序所产生的...

Var xy = x * y;
xy.set_character('*');
Var k = z + xy;
k.set_character('k');

一个更好的解决方法可能是通过值 捕获孩子。


作为捕获此类错误的一般技巧,当您的程序似乎正在执行莫名其妙的操作(和/或崩溃)时,请尝试在诸如valgrind之类的内存错误检测器下运行它。对于您的代码,报告以以下内容开头:

k = 37
∂k/∂x = 6
∂k/∂y = 5
∂k/∂z = 1

捕获它的另一种方法可以是在析构函数中添加日志记录,以便您知道日志记录中提到的对象地址何时不再有效。