实施“完全卷积”以找到没有卷积层输入的梯度

时间:2020-01-14 20:54:18

标签: python numpy gradient convolution backpropagation

我一直在尝试对卷积层输入实施“完全卷积”。根据{{​​3}}的文章,它看起来像这样:

this

所以,我写了这个函数:

def full_convolve(filters, gradient):
    filters = np.ones((5,5))
    gradient = np.ones((8,8))
    result = list()
    output_shape = 12
    filter_r = filters.shape[0] - 1
    filter_c = filters.shape[1] - 1
    gradient_r = gradient.shape[0] - 1
    gradient_c = gradient.shape[1] - 1

    for i in range(0,output_shape):
        if (i <= filter_r):
            row_slice = (0, i + 1)
            filter_row_slice = ( 0 , i + 1)
        elif ( i > filter_r and i <= gradient_r):
            row_slice = (i - filter_r, i + 1)
            filter_row_slice = (0, i + 1)
        else: 
            rest = ((output_shape - 1) -  i )
            row_slice = (gradient_r  - rest, i + 1 )
            filter_row_slice = (0 ,rest + 1)
        for b in range(0,output_shape):
            if (b <= filter_c):
                col_slice = (0, b + 1)
                filter_col_slice = (0, b+1)
            elif (b > filter_c and b <= gradient_c):
                col_slice = (b - filter_c, b + 1)
                filter_col_slice = (0,b+1)
            else:
                rest = (output_shape - 1 ) - b 
                col_slice = (gradient_r - rest , b + 1)
                filter_col_slice = (0, rest + 1)
            r = np.sum(gradient[row_slice[0] : row_slice[1], col_slice[0] : col_slice[1]] * filters[filter_row_slice[0]: filter_row_slice[1], filter_col_slice[0]: filter_col_slice[1]])
            result.append(r)
    result = np.asarray(result).reshape(12,12)

我对此进行了测试,输出似乎正确(如果我正确地进行了“全卷积”):

[[ 1.  2.  3.  4.  5.  5.  5.  5.  4.  3.  2.  1.]
 [ 2.  4.  6.  8. 10. 10. 10. 10.  8.  6.  4.  2.]
 [ 3.  6.  9. 12. 15. 15. 15. 15. 12.  9.  6.  3.]
 [ 4.  8. 12. 16. 20. 20. 20. 20. 16. 12.  8.  4.]
 [ 5. 10. 15. 20. 25. 25. 25. 25. 20. 15. 10.  5.]
 [ 5. 10. 15. 20. 25. 25. 25. 25. 20. 15. 10.  5.]
 [ 5. 10. 15. 20. 25. 25. 25. 25. 20. 15. 10.  5.]
 [ 5. 10. 15. 20. 25. 25. 25. 25. 20. 15. 10.  5.]
 [ 4.  8. 12. 16. 20. 20. 20. 20. 16. 12.  8.  4.]
 [ 3.  6.  9. 12. 15. 15. 15. 15. 12.  9.  6.  3.]
 [ 2.  4.  6.  8. 10. 10. 10. 10.  8.  6.  4.  2.]
 [ 1.  2.  3.  4.  5.  5.  5.  5.  4.  3.  2.  1.]]

但是,我不喜欢所有这些手动检查以及if / else语句。我觉得有一种更好的方法可以在NumPy中实现(也许使用一些零填充或类似的东西)。谁能建议一种更好的方法?谢谢

0 个答案:

没有答案