如何添加数组列表(张量)

时间:2019-10-29 05:54:17

标签: python numpy tensorflow pytorch

我正在定义一个简单的conv2d函数,以计算输入和内核(均为二维张量)之间的互相关,如下所示:

import torch 

def conv2D(X, K):
    h = K.shape[0]
    w = K.shape[1]
    ĥ = X.shape[0] - h + 1
    ŵ = X.shape[1] - w + 1
    Y = torch.zeros((ĥ, ŵ)) 
    for i in range (ĥ):
        for j in range (ŵ):
            Y[i, j] = (X[i: i+h, j: j+w]*K).sum()

    return Y 

当X和K为3级张量时,我计算每个通道的conv2d,然后将它们加在一起,如下所示:

def conv2D_multiple(X, K):
    cross = []
    result = 0
    for x, k in zip(X, K):
        cross.append(conv2D(x,k))

    for t in cross:
        result += t

    return result 

要测试我的功能,

X_2 = torch.tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]], 
                    [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=torch.float32)
K_2 = torch.tensor([[[0, 1], [2, 3]], [[1, 2], [3, 4]]], dtype=torch.float32)

conv2D_multiple(X_2, K_2)

结果是:

tensor([[ 56.,  72.],
        [104., 120.]])

结果符合预期,但是,我相信我的第二个 conv2D_multiple(X, K)函数内部的for循环是多余的。我的问题是如何求和(明智的选择) 列表中的张量(数组),因此我省略了第二个for循环。

1 个答案:

答案 0 :(得分:1)

由于您的conv2D对每个切片行为进行操作,因此您可以分配3D张量,以便在使用第一个for循环时,通过获取每个结果并填充来存储结果每个切片。然后,您可以使用张量上的PyTorch内置的torch.sum运算符沿切片的尺寸求和,以获得相同的结果。为了使其可口,我将切片尺寸设为dim=0。因此,将cross从最初的空列表替换为3D的Torch张量,以允许您存储中间结果,然后通过求和沿切片尺寸进行压缩。我们可以避免这样做,因为您的初始实现将中间结果存储为2D张量列表。为了简化操作,请转到3D并允许PyTorch沿切片轴求和。

这将要求您在循环之前首先为此3D张量定义正确的尺寸:

def conv2D_multiple(X, K):
    h = K.shape[1]
    w = K.shape[2]
    ĥ = X.shape[1] - h + 1
    ŵ = X.shape[2] - w + 1
    c = X.shape[0]
    cross = torch.zeros((c, ĥ, ŵ), dtype=torch.float32)
    for i, (x, k) in enumerate(zip(X, K)):
        cross[i] = conv2D(x,k)

    result = cross.sum(dim=0)
    return result

请注意,对于您在输入和内核之间进行迭代的每个切片,我们无需将其附加到新列表,而是将其直接放入中间张量的切片中。一旦存储了这些结果,就沿切片轴求和以最终将其压缩为所需的值。使用示例输入运行上面的新功能会产生相同的结果。


如果这不是您想要的结果,另一种方法是简单地获取创建的张量列表,使用torch.stack将所有张量堆叠在一起并求和,从而构建中间张量。默认情况下,它沿第一个轴(dim=0)堆叠:

def conv2D_multiple(X, K):
    cross = []
    result = 0
    for x, k in zip(X, K):
        cross.append(conv2D(x,k))

    cross = torch.stack(cross)
    result = cross.sum(dim=0)
    return result