如何在pytorch关联中进行卷积?

时间:2019-04-03 16:21:17

标签: pytorch convolution

根据定义,离散卷积是关联的。但是,当我尝试在pytorch中验证这一点时,我找不到一个合理的结果。

关联律是$ f *(g * \ psi)=(f * g)* \ psi $,因此我创建了三个以零为中心的离散函数(作为张量),并将它们与适当的零填充进行卷积,以便所有获得结果图中的非零元素。

import torch
import torch.nn as nn

def test_conv_compst():
    # $\psi$
    inputs = torch.randn((1,4,7,7))
    # $g$
    a = torch.randn((7, 4, 3, 3))
    # $f$
    b = torch.randn((3, 7, 3, 3))
    int_1 = torch.conv2d(inputs, a, padding=2)
    # results obtained by the first order
    res_1 = torch.conv2d(int_1, b, padding=2)

    comp_k = torch.conv2d(a.transpose(1, 0), b, padding=2).transpose(1, 0)
    print(comp_k.shape)
    # results obtained through the second order
    res_2 = torch.conv2d(inputs, comp_k, padding=4)
    print(res_1.shape)
    print(res_2.shape)
    print(torch.max(torch.abs(res_2-res_1)))

预期结果是与两个结果的差异可以忽略不计。但它返回:

torch.Size([3, 4, 5, 5])
torch.Size([1, 3, 11, 11])
torch.Size([1, 3, 11, 11])
tensor(164.8044)

1 个答案:

答案 0 :(得分:0)

长话短说,这是因为批量处理。 torch.conv2d的第一个参数解释为[batch, channel, height, width],第二个参数解释为[out_channel, in_channel, height, width],输出解释为[batch, channel, height, width]。因此,如果调用conv2d(a, conv2d(b, c)),则将b的前导维视为批处理;如果调用conv2d(conv2d(a, b), c),则将其视为out_channels

话虽如此,我的印象是您在这里问数学,所以让我扩展一下。您的想法在理论上是正确的:卷积是线性算子,应该是关联的。但是,由于我们为他们提供了内核而不是代表线性运算符的实际矩阵,因此需要在幕后进行一些“转换”,以便将内核正确地解释为矩阵。传统上,这可以通过构造相应的circulant matrices(不包括边界条件)来完成。如果我们用abc表示内核,并用M表示循环矩阵创建运算符,我们得到的是M(a) @ [M(b) @ M(c)] = [M(a) @ M(b)] @ M(c),其中{{1} }表示矩阵矩阵乘法。

卷积实现返回一个图像(矢量,内核,不过您要称呼它),而不是相关的循环矩阵,这是可笑的冗余,在大多数情况下不适合内存。因此,我们还需要一些循环矢量运算符@,该运算符返回V(matrix)的第一列,因此是matrix的反函数。用抽象数学术语来说,诸如scipy.signal.convolve(实际上是M之类的函数,因为卷积需要对输入之一进行额外的翻转,为清楚起见,我将其跳过)实现为correlate,因此< / p>

convolve = lambda a, b: V(M(a) @ M(b))

我希望我不会失去你,这只是利用convolve(a, convolve(b, c)) = = V(M(a) @ M(V[M(b) @ M(c)]) = V(M(a) @ M(b) @ M(c)) = V(M(V[M(a) @ M(b)]) @ M(c)) = convolve(convolve(a, b), c) V的倒数以及矩阵乘法移动的关联性这一事实将彼此转换括号。请注意,中间行基本上是“原始” M。我们可以使用以下代码进行验证:

ABC

PyTorch的问题在于它将第一个输入解释为import numpy as np import scipy.signal as sig c2d = sig.convolve2d a = np.random.randn(7, 7) b = np.random.randn(3, 3) c = np.random.randn(3, 3) ab = c2d(a, b) ab_c = c2d(ab, c) bc = c2d(b, c) a_bc = c2d(a, bc) print((a_bc - ab_c).max()) ,第二个解释为[batch, channel, height, width]。这意味着第一个参数和第二个参数的“转换”运算符[out_channels, in_channels, height, width]不同。我们分别称它们为MM。由于只有一个输出,因此只有一个N,并且它可以是VM的倒数,但不能是两者的倒数(因为它们是不同的)。如果重写上面的等式时要注意区分NM,您将看到,根据您的选择,N是将一个求反还是将另一个求反。在第2行和第3行或第3行和第4行之间相等。

在实践中,还存在V维度的其他问题,这在卷积的经典定义中是不存在的,但是我的第一个猜测是可以使用单个提升运算符{{1 }}用于两个操作数,与批处理不同。