使用col2im进行CNN反向传播

时间:2019-02-27 11:05:24

标签: c# machine-learning conv-neural-network caffe

我一直在研究卷积神经网络的实现。最初,我使用标准卷积来实现它们,该卷积按预期工作并通过了梯度检查,因此我开始使用im2col进行前向遍历和col2im进行后退遍历来优化卷积。我已经正确实现了im2col,并且正向传递可以按预期工作,但是,我注意到反向传递无法正常工作,主要是向后传播渐变。

问题似乎是,如https://stackoverflow.com/a/51717536http://www.programmersought.com/article/16963235/中所述,col2im函数与im2col并不完全相反,在我的情况下,这导致向后传递结果到尽管col2im和矩阵乘法看起来都可以正常工作,但完全不正确。

正因为如此,由于我正努力在网上找到很多与此有关的信息,因此如何解决此问题?

我的矩阵乘法代码,col2im代码和向后传递代码如下:

        /// <summary>
    /// O = (A dot B) + C + D
    /// </summary>
    /// <param name="a">NxM dimensional input matrix</param>
    /// <param name="b">MxP dimensional input matrix</param>
    /// <param name="c">Px1 dimensional optional input matrix</param>
    /// <param name="d">Nx1 dimensional optional input matrix</param>
    /// <param name="o">NxP dimensional output matrix</param>
    public static void Mad(Matrix a, Matrix b, Matrix c, Matrix d, Matrix o, bool reset)
    {
        if (a.Columns != b.Rows)
            throw new ArgumentException();

        if (a.Rows != o.Rows)
            throw new ArgumentException();
        if (b.Columns != o.Columns)
            throw new ArgumentException();

        if (c != null && c.Rows != b.Columns)
            throw new ArgumentException();

        if (d != null && d.Rows != a.Rows)
            throw new ArgumentException();

        Parallel.For(0, a.Rows, (i) =>
        //for(int i = 0; i < a.Rows; i++)
        {
            for (int j = 0; j < b.Columns; j++)
            {
                float acc = 0;
                for (int k = 0; k < a.Columns; k++)
                {
                    acc += a.memory[a.Index(i, k)] * b.memory[b.Index(k, j)];
                }

                if (reset)
                    o.memory[o.Index(i, j)] = acc + (c == null ? 0 : c.memory[c.Index(j, 0)]) + (d == null ? 0 : d.memory[d.Index(i, 0)]);
                else
                    o.memory[o.Index(i, j)] += acc + (c == null ? 0 : c.memory[c.Index(j, 0)]) + (d == null ? 0 : d.memory[d.Index(i, 0)]);
            }
        });
    }

public static void Column2Image(int input_sz, int input_cnt, int stride_len, int padding, int filter_sz, int output_sz, Matrix input, Matrix output)
    {
        //Foreach column in input, rearrange it into a block, stripping padding and applying appropriate strides
        int block_sz = filter_sz * filter_sz * input_cnt;
        int len = output_sz * output_sz;


        input.Clear();
        //Using a version based on Caffe for comparison

        //for (int c = 0; c < block_sz; c++)
        Parallel.For(0, block_sz, (c) =>
        {
            int col_off = (c % filter_sz);
            int row_off = (c / filter_sz) % filter_sz;
            int c_im = c / filter_sz / filter_sz;
            for (int row = 0; row < output_sz; row++)
                for (int col = 0; col < output_sz; col++)
                {
                    int row_pad = row * stride_len - padding + row_off;
                    int col_pad = col * stride_len - padding + col_off;
                    if (row_pad >= 0 && row_pad < input_sz && col_pad >= 0 && col_pad < input_sz)
                    {
                        input.memory[input.Index(c_im, row_pad * input_sz + col_pad)] += output.memory[c * output_sz * output_sz + row * output_sz + col];
                    }
                }
        });
    }

    public Matrix[] Propagate(Matrix[] prev_delta)
    {
        //cur_delta = BackwardDelta = Full convolution of prev_delta with 180 rotated filter <- sum from all filters in terms of filterCnt, but spread across inputDepth? 
        //Flatten and transpose the Weights as is, dot the flattened prev_delta
        Matrix.Mad(Weights.Transpose(), prev_delta[0].Reshape(filterCnt, outputSz * outputSz), null, null, BackwardDeltaIC, true);
        Matrix.Column2Image(inputSz, inputDepth, strideLen, paddingSz, filterSz, outputSz, BackwardDelta.Reshape(inputDepth, inputSz * inputSz), BackwardDeltaIC);
        //col2im the result
        return new Matrix[] { BackwardDelta };
    }

我尝试查看caffe的实现进行比较,但是我不知道他们如何处理它。任何信息将不胜感激!

编辑:我还应该提到,我认为问题部分与填充非零有关,这似乎也是https://github.com/pytorch/pytorch/blob/master/aten/src/THNN/generic/Col2Im.c简短讨论的内容,但我并不十分了解它们如何解决问题。

0 个答案:

没有答案