如何在C ++中反向传播神经网络?

时间:2018-05-14 16:30:57

标签: c++ neural-network

我一直在尝试用c ++创建一个神经网络,而我的反向传播代码并没有按照我想要的方式工作。我有一个文本文档,告诉网络如何运作。我有它有2个输入神经元,1个隐藏层有4个神经元和2个输出神经元。我现在正在学习如何成为一个XOR门。我有它所以它需要网络的成本,乘以.55(缩放)和加权/减去权重/偏差'取决于输出与正确答案的接近程度,以及权重/偏差是否接近'是+或 - 。这是代码:

void Network::backProp(void)
{
double b = 0,a;
int loop,l;


for(loop=0;loop<4;loop++)
{
    //Adds up the cost of the data
    b = b + (pow(results[2*loop]-key[4*loop+2],2)+pow(results[2*loop+1]-key[4*loop+3],2));
}
a=.55*b;
if(b>.01)
{
for(l=0;l<4;l++)
{
    if(round(results[2*l])!=key[4*l+2])
    {
        if(data[0] <= 0)
        {
            data[0] = data[0]+a; //(abs(data[0])/a);
        }
        else
        {
            data[0] = data[0]-a; //(abs(data[0])/a);
        }
        if(data[1] <= 0)
        {
            data[1] = data[1]+a; //(abs(data[1])/a);
        }
        else
        {
            data[1] = data[1]-a; //(abs(data[1])/a);
        }
        if(data[2] <= 0)
        {
            data[2] = data[2]+a; //(abs(data[2])/a);
        }
        else
        {
            data[2] = data[2]-a; //(abs(data[2])/a);
        }
        if(data[3] <= 0)
        {
            data[3] = data[3]+a; //(abs(data[3])/a);
        }
        else
        {
            data[3] = data[3]-a; //(abs(data[3])/a);
        }
        if(data[4] <= 0)
        {
            data[4] = data[4]+a; //(abs(data[4])/a);
        }
        else
        {
            data[4] = data[4]-a; //(abs(data[4])/a);
        }
        if(data[6] <= 0)
        {
            data[6] = data[6]+a; //(abs(data[6])/a);
        }
        else
        {
            data[6] = data[6]-a; //(abs(data[6])/a);
        }
        if(data[7] <= 0)
        {
            data[7] = data[7]+a; //(abs(data[7])/a);
        }
        else
        {
            data[7] = data[7]-a; //(abs(data[7])/a);
        }
        if(data[8] <= 0)
        {
            data[8] = data[8]+a; //(abs(data[8])/a);
        }
        else
        {
            data[8] = data[8]-a; //(abs(data[8])/a);
        }
        if(data[9] <= 0)
        {
            data[9] = data[9]+a; //(abs(data[9])/a);
        }
        else
        {
            data[9] = data[9]-a; //(abs(data[9])/a);
        }
        if(data[10] <= 0)
        {
            data[10] = data[10]+a; //(abs(data[10])/a);
        }
        else
        {
            data[10] = data[10]-a; //(abs(data[10])/a);
        }
        if(data[11] <= 0)
        {
            data[11] = data[11]+a; //(abs(data[11])/a);
        }
        else
        {
            data[11] = data[11]-a; //(abs(data[11])/a);
        }
        if(data[12] <= 0)
        {
            data[12] = data[12]+a; //(abs(data[12])/a);
        }
        else
        {
            data[12] = data[12]-a; //(abs(data[12])/a);
        }
        if(data[13] <= 0)
        {
            data[13] = data[13]+a; //(abs(data[13])/a);
        }
        else
        {
            data[13] = data[13]-a; //(abs(data[13])/a);
        }
        if(data[14] <= 0)
        {
            data[14] = data[14]+a; //(abs(data[14])/a);
        }
        else
        {
            data[14] = data[14]-a; //(abs(data[14])/a);
        }
        if(data[16] <= 0)
        {
            data[16] = data[16]+a; //(abs(data[16])/a);
        }
        else
        {
            data[16] = data[16]-a; //(abs(data[16])/a);
        }
        if(data[18] <= 0)
        {
            data[18] = data[18]+a; //(abs(data[18])/a);
        }
        else
        {
            data[18] = data[18]-a; //(abs(data[18])/a);
        }
        if(data[20] <= 0)
        {
            data[20] = data[20]+a; //(abs(data[20])/a);
        }
        else
        {
            data[20] = data[20]-a; //(abs(data[20])/a);
        }
    }
    else
    {
        if(data[0] <= 0)
        {
            data[0] = data[0]-a; //(abs(data[0])/a);
        }
        else
        {
            data[0] = data[0]+a; //(abs(data[0])/a);
        }
        if(data[1] <= 0)
        {
            data[1] = data[1]-a; //(abs(data[1])/a);
        }
        else
        {
            data[1] = data[1]+a; //(abs(data[1])/a);
        }
       if(data[2] <= 0)
       {
            data[2] = data[2]-a; //(abs(data[2])/a);
       }
        else
        {
            data[2] = data[2]+a; //(abs(data[2])/a);
        }
        if(data[3] <= 0)
        {
            data[3] = data[3]-a; //(abs(data[3])/a);
        }
        else
        {
            data[3] = data[3]+a; //(abs(data[3])/a);
        }
        if(data[4] <= 0)
        {
            data[4] = data[4]-a; //(abs(data[4])/a);
        }
        else
        {
            data[4] = data[4]+a; //(abs(data[4])/a);
        }
        if(data[6] <= 0)
        {
            data[6] = data[6]-a; //(abs(data[6])/a);
        }
        else
        {
            data[6] = data[6]+a; //(abs(data[6])/a);
        }
        if(data[7] <= 0)
        {
            data[7] = data[7]-a; //(abs(data[7])/a);
        }
        else
        {
            data[7] = data[7]+a; //(abs(data[7])/a);
        }
        if(data[8] <= 0)
        {
            data[8] = data[8]-a; //(abs(data[8])/a);
        }
        else
        {
            data[8] = data[8]+a; //(abs(data[8])/a);
        }
        if(data[9] <= 0)
        {
            data[9] = data[9]-a; //(abs(data[9])/a);
        }
        else
        {
            data[9] = data[9]+a; //(abs(data[9])/a);
        }
        if(data[10] <= 0)
        {
            data[10] = data[10]-a; //(abs(data[10])/a);
        }
        else
        {
            data[10] = data[10]+a; //(abs(data[10])/a);
        }
        if(data[11] <= 0)
        {
            data[11] = data[11]-a; //(abs(data[11])/a);
        }
        else
        {
            data[11] = data[11]+a; //(abs(data[11])/a);
        }
        if(data[12] <= 0)
        {
            data[12] = data[12]-a; //(abs(data[12])/a);
        }
        else
        {
            data[12] = data[12]+a; //(abs(data[12])/a);
        }
        if(data[13] <= 0)
        {
            data[13] = data[13]-a; //(abs(data[13])/a);
        }
        else
        {
            data[13] = data[13]+a; //(abs(data[13])/a);
        }
        if(data[14] <= 0)
        {
            data[14] = data[14]-a; //(abs(data[14])/a);
        }
        else
        {
            data[14] = data[14]+a; //(abs(data[14])/a);
        }
        if(data[16] <= 0)
        {
            data[16] = data[16]-a; //(abs(data[16])/a);
        }
        else
        {
            data[16] = data[16]+a; //(abs(data[16])/a);
        }
        if(data[18] <= 0)
        {
            data[18] = data[18]-a; //(abs(data[18])/a);
        }
        else
        {
            data[18] = data[18]+a; //(abs(data[18])/a);
        }
        if(data[20] <= 0)
        {
            data[20] = data[20]-a; //(abs(data[20])/a);
        }
        else
        {
            data[20] = data[20]+a; //(abs(data[20])/a);
        }
    }
    if(round(results[2*l+1])!=key[4*l+3])
    {
        if(data[0] <= 0)
        {
            data[0] = data[0]+a; //(abs(data[0])/a);
        }
        else
        {
            data[0] = data[0]-a; //(abs(data[0])/a);
        }
        if(data[1] <= 0)
        {
            data[1] = data[1]+a; //(abs(data[1])/a);
        }
        else
        {
            data[1] = data[1]-a; //(abs(data[1])/a);
        }
        if(data[2] <= 0)
        {
            data[2] = data[2]+a; //(abs(data[2])/a);
        }
        else
        {
            data[2] = data[2]-a; //(abs(data[2])/a);
        }
        if(data[3] <= 0)
        {
            data[3] = data[3]+a; //(abs(data[3])/a);
        }
        else
        {
            data[3] = data[3]-a; //(abs(data[3])/a);
        }
        if(data[4] <= 0)
        {
            data[4] = data[4]+a; //(abs(data[4])/a);
        }
        else
        {
            data[4] = data[4]-a; //(abs(data[4])/a);
        }
        if(data[5] <= 0)
        {
            data[5] = data[5]+a; //(abs(data[5])/a);
        }
        else
        {
            data[5] = data[5]-a; //(abs(data[5])/a);
        }
        if(data[7] <= 0)
        {
            data[7] = data[7]+a; //(abs(data[7])/a);
        }
        else
        {
            data[7] = data[7]-a; //(abs(data[7])/a);
        }
        if(data[8] <= 0)
        {
            data[8] = data[8]+a; //(abs(data[8])/a);
        }
        else
        {
            data[8] = data[8]-a; //(abs(data[8])/a);
        }
        if(data[9] <= 0)
        {
            data[9] = data[9]+a; //(abs(data[9])/a);
        }
        else
        {
            data[9] = data[9]-a; //(abs(data[9])/a);
        }
        if(data[10] <= 0)
        {
            data[10] = data[10]+a; //(abs(data[10])/a);
        }
        else
        {
            data[10] = data[10]-a; //(abs(data[10])/a);
        }
        if(data[11] <= 0)
        {
            data[11] = data[11]+a; //(abs(data[11])/a);
        }
        else
        {
            data[11] = data[11]-a; //(abs(data[11])/a);
        }
        if(data[12] <= 0)
        {
            data[12] = data[12]+a; //(abs(data[12])/a);
        }
        else
        {
            data[12] = data[12]-a; //(abs(data[12])/a);
        }
        if(data[13] <= 0)
        {
            data[13] = data[13]+a; //(abs(data[13])/a);
        }
        else
        {
            data[13] = data[13]-a; //(abs(data[13])/a);
        }
        if(data[15] <= 0)
        {
            data[15] = data[15]+a; //(abs(data[15])/a);
        }
        else
        {
            data[15] = data[15]-a; //(abs(data[15])/a);
        }
        if(data[17] <= 0)
        {
            data[17] = data[17]+a; //(abs(data[17])/a);
        }
        else
        {
            data[17] = data[17]-a; //(abs(data[17])/a);
        }
        if(data[19] <= 0)
        {
            data[19] = data[19]+a; //(abs(data[19])/a);
        }
        else
        {
            data[19] = data[19]-a; //(abs(data[19])/a);
        }
        if(data[21] <= 0)
        {
            data[21] = data[21]+a; //(abs(data[21])/a);
        }
        else
        {
            data[21] = data[21]-a; //(abs(data[21])/a);
        }
    }
    else
    {
        if(data[0] <= 0)
        {
            data[0] = data[0]-a; //(abs(data[0])/a);
        }
        else
        {
            data[0] = data[0]+a; //(abs(data[0])/a);
        }
        if(data[1] <= 0)
        {
            data[1] = data[1]-a; //(abs(data[1])/a);
        }
        else
        {
            data[1] = data[1]+a; //(abs(data[1])/a);
        }
       if(data[2] <= 0)
       {
            data[2] = data[2]-a; //(abs(data[2])/a);
       }
        else
        {
            data[2] = data[2]+a; //(abs(data[2])/a);
        }
        if(data[3] <= 0)
        {
            data[3] = data[3]-a; //(abs(data[3])/a);
        }
        else
        {
            data[3] = data[3]+a; //(abs(data[3])/a);
        }
        if(data[4] <= 0)
        {
            data[4] = data[4]-a; //(abs(data[4])/a);
        }
        else
        {
            data[4] = data[4]+a; //(abs(data[4])/a);
        }
        if(data[5] <= 0)
        {
            data[5] = data[5]-a; //(abs(data[5])/a);
        }
        else
        {
            data[5] = data[5]+a; //(abs(data[5])/a);
        }
        if(data[7] <= 0)
        {
            data[7] = data[7]-a; //(abs(data[7])/a);
        }
        else
        {
            data[7] = data[7]+a; //(abs(data[7])/a);
        }
        if(data[8] <= 0)
        {
            data[8] = data[8]-a; //(abs(data[8])/a);
        }
        else
        {
            data[8] = data[8]+a; //(abs(data[8])/a);
        }
        if(data[9] <= 0)
        {
            data[9] = data[9]-a; //(abs(data[9])/a);
        }
        else
        {
            data[9] = data[9]+a; //(abs(data[9])/a);
        }
        if(data[10] <= 0)
        {
            data[10] = data[10]-a; //(abs(data[10])/a);
        }
        else
        {
            data[10] = data[10]+a; //(abs(data[10])/a);
        }
        if(data[11] <= 0)
        {
            data[11] = data[11]-a; //(abs(data[11])/a);
        }
        else
        {
            data[11] = data[11]+a; //(abs(data[11])/a);
        }
        if(data[12] <= 0)
        {
            data[12] = data[12]-a; //(abs(data[12])/a);
        }
        else
        {
            data[12] = data[12]+a; //(abs(data[12])/a);
        }
        if(data[13] <= 0)
        {
            data[13] = data[13]-a; //(abs(data[13])/a);
        }
        else
        {
            data[13] = data[13]+a; //(abs(data[13])/a);
        }
        if(data[15] <= 0)
        {
            data[15] = data[15]-a; //(abs(data[15])/a);
        }
        else
        {
            data[15] = data[15]+a; //(abs(data[15])/a);
        }
        if(data[17] <= 0)
        {
            data[17] = data[17]-a; //(abs(data[17])/a);
        }
        else
        {
            data[17] = data[17]+a; //(abs(data[17])/a);
        }
        if(data[19] <= 0)
        {
            data[19] = data[19]-a; //(abs(data[19])/a);
        }
        else
        {
            data[19] = data[19]+a; //(abs(data[19])/a);
        }
        if(data[21] <= 0)
        {
            data[21] = data[21]-a; //(abs(data[21])/a);
        }
        else
        {
            data[21] = data[21]+a; //(abs(data[21])/a);
        }
    }
 }
 }
}

我知道这很乱,但这就是我提出来的。我可以发布其余的代码,如果这会有所帮助。

2 个答案:

答案 0 :(得分:0)

将此块移动到单独的功能:

    if(data[0] <= 0)
    {
        data[0] = data[0]+a; //(abs(data[0])/a);
    }
    else
    {
        data[0] = data[0]-a; //(abs(data[0])/a);
    }
像这样:(找到合适的名字)

void AddAtoData(int& data, a)
{
    if(data <= 0)
    {
        data += a;
    }
    else
    {
        data -= a;
    }
}

然后将您的data []结构拆分为逻辑单元,例如您的图层以避免其他if-logic和使用循环。

清理完毕后,查看问题是否仍然存在,如果是,请回来。

答案 1 :(得分:0)

以下是代码的简化版

void Network::backProp(void)
{
    double b = 0,a;
    int loop,l;
    int inclusion1 [] = {0,1,2,3,4,6,7,8,9,10,11,12,13,14,16,18,20};
    int inclusion2 [] = {0,1,2,3,4,5,7,8,9,10,11,12,13,15,17,19,21};
    int j = 0;
    for(loop=0;loop<4;loop++)
    {
    //Adds up the cost of the data
        b = b + (pow(results[2*loop]-key[4*loop+2],2)+pow(results[2*loop+1]-key[4*loop+3],2));
    }
    a=.55*b;

    if(b>.01)
    {
        for(l=0;l<4;l++)
        {
            for(j=0;j<17;j++)
            {
                if(round(results[2*l])!=key[4*l+2])
                {
                    data[inclusion1[j]] = data[inclusion1[j]] - abs(data[inclusion1[j]])/data[inclusion1[j]]*a;
                }

                if(round(results[2*l+1])!=key[4*l+3])
                {
                    data[inclusion2[j]] = data[inclusion2[j]] + abs(data[inclusion2[j]])/data[inclusion2[j]]*a;
                }
            }
        }
    }
}

我看到的基本问题是你的修正变量b我认为它没有被准确定义

应该更加顺畅

b = b + pow((pow(results[2*loop]-key[4*loop+2],2)+pow(results[2*loop+1]-key[4*loop+3],2)),1/2);