了解MATLAB的行为

时间:2015-01-09 01:37:29

标签: matlab signal-processing convolution

我正在对一些张量进行卷积。

以下是MATLAB中的小测试:

    ker= rand(3,4,2);
    a= rand(5,7,2);
    c=convn(a,ker,'valid');
    c11=sum(sum(a(1:3,1:4,1).*ker(:,:,1)))+sum(sum(a(1:3,1:4,2).*ker(:,:,2)));
    c(1,1)-c11  % not equal!

第三行使用convn执行N-D卷积,我想比较第一行的结果,convn的第一列与手动计算值。但是,与convn相比,我的计算方法并不相同。

那么MATLAB的convn背后是什么?我对张量卷积的理解是错误的吗?

2 个答案:

答案 0 :(得分:4)

是的,你对卷积的理解是错误的。你的c11公式不是卷积:你只需乘以匹配指数然后求和。它更像是一个点积运算(在张量调整到相同尺寸的张量上)。我将尝试从1维开始解释。

1维数组

输入conv([4 5 6], [2 3])会返回[8 22 27 18]。我发现在多项式的乘法方面最容易想到这一点:

(4 + 5x + 6x ^ 2)*(2 + 3x)= 8 + 22x + 27x ^ 2 + 18x ^ 3

使用每个数组的条目作为多项式的系数,乘以多项式,收集类似的项,并从系数中读出结果。 x的权力在这里是为了跟踪增加和增加的内容。注意,x ^ n的系数在第(n + 1)项中找到,因为x的幂从0开始,而索引以1开始。

二维数组

输入conv2([2 3; 3 1], [4 5 6; 0 -1 1])会返回矩阵

 8  22  27  18
12  17  22   9
 0  -3   2   1

同样,这可以解释为多项式的乘法,但现在我们需要两个变量:比如说x和y。 x ^ n y ^ m的系数在(m + 1,n + 1)项中找到。以上输出意味着

(2 + 3x + 3y + xy)*(4 + 5x + 6x ^ 2 + 0y-xy + x ^ 2y)= 8 + 22x + 27x ^ 2 + 18x ^ 3 + 12y + 17xy + 22x ^ 2y + 9x ^ 3y-3xy ^ 2 + 2x ^ 2y ^ 2 + x ^ 3y ^ 2

三维数组

同样的故事。您可以将条目视为变量x,y,z中多项式的系数。多项式乘以,乘积系数是卷积的结果。

'有效'参数

这只保留了卷积的中心部分:第二个因子的所有项所参与的那些系数。为了使其非空,第二个数组的尺寸应不大于第一个。 (这与默认设置不同,订单卷积数组无关紧要。)示例:

conv([4 5 6], [2 3])返回[22 27](与上面的1维示例相比)。这相当于

中的事实

(4 + 5x + 6x ^ 2)*(2 + 3x)= 8+ 22x + 27x ^ 2 + 18x ^ 3

粗体术语来自 2和3x的贡献。

答案 1 :(得分:4)

几乎让它正确无误。你的理解有两点错误:

  1. 您选择valid作为卷积标记。这意味着从卷积返回的输出具有其大小,因此当您使用内核扫描矩阵时,它必须舒适地适合矩阵本身。因此,第一个"有效"返回的输出实际上是用于矩阵位置(2,2,1)的计算。这意味着您可以在这个位置舒适地调整内核,这对应于输出的位置(1,1)。为了演示,使用上面的代码,这就是aker的样子:

    >> a
    
    a(:,:,1) =
    
    0.9930    0.2325    0.0059    0.2932    0.1270    0.8717    0.3560
    0.2365    0.3006    0.3657    0.6321    0.7772    0.7102    0.9298
    0.3743    0.6344    0.5339    0.0262    0.0459    0.9585    0.1488
    0.2140    0.2812    0.1620    0.8876    0.7110    0.4298    0.9400
    0.1054    0.3623    0.5974    0.0161    0.9710    0.8729    0.8327
    
    
    a(:,:,2) =
    
    0.8461    0.0077    0.5400    0.2982    0.9483    0.9275    0.8572
    0.1239    0.0848    0.5681    0.4186    0.5560    0.1984    0.0266
    0.5965    0.2255    0.2255    0.4531    0.5006    0.0521    0.9201
    0.0164    0.8751    0.5721    0.9324    0.0035    0.4068    0.6809
    0.7212    0.3636    0.6610    0.5875    0.4809    0.3724    0.9042
    
    >> ker
    
    ker(:,:,1) =
    
    0.5395    0.4849    0.0970    0.3418
    0.6263    0.9883    0.4619    0.7989
    0.0055    0.3752    0.9630    0.7988
    
    
    ker(:,:,2) =
    
    0.2082    0.4105    0.6508    0.2669
    0.4434    0.1910    0.8655    0.5021
    0.7156    0.9675    0.0252    0.0674
    

    正如您所看到的,在矩阵(2,2,1)中的a位置,ker可以很好地适应矩阵,如果您从卷积中回忆起来,它只是元素的总和 - 内核和位置(2,2,1)的矩阵子集之间的元素乘积与内核的大小相同(实际上,你需要对内核执行其他操作,我将为下一点保留一些内容 - 请参阅下面)。因此,您计算的系数实际上是(2,2,1)的输出,而不是(1,1,1)的输出。从它的要点来看,你已经知道了这一点,但我想把它放在那里以防万一你不知道。

  2. 您忘记了对于N-D卷积,您需要在每个维度中翻转遮罩。如果您记得1D卷积,则必须在水平方向翻转蒙版。翻转的意思是你只需按相反的顺序放置元素。例如,[1 2 3 4]的数组将变为[4 3 2 1]。在2D卷积中,您必须水平和垂直翻转。因此,您将获取矩阵的每一行并按相反顺序放置每一行,这与1D情况非常相似。在这里,您将每行视为一维信号并进行翻转。完成此操作后,您将获得此翻转结果,并将每个视为一维信号并再次进行翻转。

    现在,在您的3D情况下,您必须水平翻转,垂直和暂时。这意味着您需要独立地为矩阵的每个切片执行2D翻转,然后以3D方式抓取单个列并将其视为一维信号。在MATLAB语法中,您将获得ker(1,1,:),将其视为一维信号,然后翻转。您将对ker(1,2,:)ker(1,3,:)等重复此操作,直到完成第一个切片为止。请记住,我们不会进入第二片或任何其他片段并重复我们刚才所做的事情。因为您正在拍摄矩阵的3D部分,所以对于您提取的每个3D列,您固有地操作所有切片。因此,只需查看矩阵的第一个切片,因此在计算卷积之前需要对内核执行此操作:

    ker_flipped = flipdim(flipdim(flipdim(ker, 1), 2), 3);
    

    flipdim执行指定轴上的翻转。在我们的例子中,我们是垂直地做,然后取结果并水平地做,然后再次暂时做。然后,您将在总和中使用ker_flipped。请注意,翻转的顺序并不重要。 flipdim独立地对每个维度进行操作,因此只要您记得翻转所有维度,输出就会相同。


  3. 为了演示,这里使用convn输出的内容:

    c =
    
        4.1837    4.1843    5.1187    6.1535
        4.5262    5.3253    5.5181    5.8375
        5.1311    4.7648    5.3608    7.1241
    

    现在,要手动确定c(1,1)是什么,您需要在翻转的内核上进行计算:

    ker_flipped = flipdim(flipdim(flipdim(ker, 1), 2), 3);
    c11 = sum(sum(a(1:3,1:4,1).*ker_flipped(:,:,1)))+sum(sum(a(1:3,1:4,2).*ker_flipped(:,:,2)));
    

    我们得到的结果是:

    c11 =
    
        4.1837
    

    正如您所看到的,这将验证我们使用convn在MATLAB中完成的计算得到的结果。如果要比较更多精度数字,请使用format long并将它们进行比较:

    >> format long;
    >> disp(c11)
    
       4.183698205668000
    
    >> disp(c(1,1))
    
       4.183698205668001
    

    如您所见,除最后一个数字外,所有数字都相同。这归因于数字四舍五入。绝对确定:

    >> disp(abs(c11 - c(1,1)));
    
       8.881784197001252e-16
    

    ......我认为订单的差异或10 -16 足以让我表明他们是平等的,对吧?