根据定义,离散卷积是关联的。但是,当我尝试在pytorch中验证这一点时,我找不到一个合理的结果。
关联律是$ f *(g * \ psi)=(f * g)* \ psi $,因此我创建了三个以零为中心的离散函数(作为张量),并将它们与适当的零填充进行卷积,以便所有获得结果图中的非零元素。
import torch
import torch.nn as nn
def test_conv_compst():
# $\psi$
inputs = torch.randn((1,4,7,7))
# $g$
a = torch.randn((7, 4, 3, 3))
# $f$
b = torch.randn((3, 7, 3, 3))
int_1 = torch.conv2d(inputs, a, padding=2)
# results obtained by the first order
res_1 = torch.conv2d(int_1, b, padding=2)
comp_k = torch.conv2d(a.transpose(1, 0), b, padding=2).transpose(1, 0)
print(comp_k.shape)
# results obtained through the second order
res_2 = torch.conv2d(inputs, comp_k, padding=4)
print(res_1.shape)
print(res_2.shape)
print(torch.max(torch.abs(res_2-res_1)))
预期结果是与两个结果的差异可以忽略不计。但它返回:
torch.Size([3, 4, 5, 5])
torch.Size([1, 3, 11, 11])
torch.Size([1, 3, 11, 11])
tensor(164.8044)
答案 0 :(得分:0)
torch.conv2d
的第一个参数解释为[batch, channel, height, width]
,第二个参数解释为[out_channel, in_channel, height, width]
,输出解释为[batch, channel, height, width]
。因此,如果调用conv2d(a, conv2d(b, c))
,则将b
的前导维视为批处理;如果调用conv2d(conv2d(a, b), c)
,则将其视为out_channels
。
话虽如此,我的印象是您在这里问数学,所以让我扩展一下。您的想法在理论上是正确的:卷积是线性算子,应该是关联的。但是,由于我们为他们提供了内核而不是代表线性运算符的实际矩阵,因此需要在幕后进行一些“转换”,以便将内核正确地解释为矩阵。传统上,这可以通过构造相应的circulant matrices(不包括边界条件)来完成。如果我们用a
,b
,c
表示内核,并用M
表示循环矩阵创建运算符,我们得到的是M(a) @ [M(b) @ M(c)] = [M(a) @ M(b)] @ M(c)
,其中{{1} }表示矩阵矩阵乘法。
卷积实现返回一个图像(矢量,内核,不过您要称呼它),而不是相关的循环矩阵,这是可笑的冗余,在大多数情况下不适合内存。因此,我们还需要一些循环矢量运算符@
,该运算符返回V(matrix)
的第一列,因此是matrix
的反函数。用抽象数学术语来说,诸如scipy.signal.convolve
(实际上是M
之类的函数,因为卷积需要对输入之一进行额外的翻转,为清楚起见,我将其跳过)实现为correlate
,因此< / p>
convolve = lambda a, b: V(M(a) @ M(b))
我希望我不会失去你,这只是利用convolve(a, convolve(b, c)) =
= V(M(a) @ M(V[M(b) @ M(c)])
= V(M(a) @ M(b) @ M(c))
= V(M(V[M(a) @ M(b)]) @ M(c))
= convolve(convolve(a, b), c)
是V
的倒数以及矩阵乘法移动的关联性这一事实将彼此转换括号。请注意,中间行基本上是“原始” M
。我们可以使用以下代码进行验证:
ABC
PyTorch的问题在于它将第一个输入解释为import numpy as np
import scipy.signal as sig
c2d = sig.convolve2d
a = np.random.randn(7, 7)
b = np.random.randn(3, 3)
c = np.random.randn(3, 3)
ab = c2d(a, b)
ab_c = c2d(ab, c)
bc = c2d(b, c)
a_bc = c2d(a, bc)
print((a_bc - ab_c).max())
,第二个解释为[batch, channel, height, width]
。这意味着第一个参数和第二个参数的“转换”运算符[out_channels, in_channels, height, width]
不同。我们分别称它们为M
和M
。由于只有一个输出,因此只有一个N
,并且它可以是V
或M
的倒数,但不能是两者的倒数(因为它们是不同的)。如果重写上面的等式时要注意区分N
和M
,您将看到,根据您的选择,N
是将一个求反还是将另一个求反。在第2行和第3行或第3行和第4行之间相等。
在实践中,还存在V
维度的其他问题,这在卷积的经典定义中是不存在的,但是我的第一个猜测是可以使用单个提升运算符{{1 }}用于两个操作数,与批处理不同。