在Pytorch中在tensor [batch,channel,sequence,H,W]上运行conv2d

时间:2017-07-27 14:18:36

标签: python deep-learning pytorch

我正在制作视频帧数据,我将输入数据作为[批量,通道,帧序列,高度,重量]形式的张量(让[B,C,S,H,W]表示)因此,每个批次基本上由一个连续的帧序列组成。我基本上想要做的是在每个帧上运行一个编码器(由几个conv2d组成),即每个[C,H,W]并将其作为[B, C_output,S,H_output,W_output]。现在,conv2d期望输入为(N,C_in,H_in,W_in)形式。我想知道在不破坏5D张量内的顺序的情况下,最好的方法是什么。 到目前为止,我正在考虑采用以下方式:

>>> # B,C,seq,h,w
# 4,2, 5,  3,3 

>>> x = Variable(torch.rand(4,2,5,3,3))
>>> x.size() 
#torch.Size([4, 2, 5, 3, 3])
>>> x = x.permute(0,2,1,3,4)
>>> x.size() #expected = 4,5,2,3,3 B,seq,C,h,w
#torch.Size([4, 5, 2, 3, 3])
>>> x = x.contiguous().view(-1,2,3,3)
>>> x.size()
#torch.Size([20, 2, 3, 3])

然后在更新的x上运行conv2d(编码器)并重新整形。但我认为它不会保留张量的原始顺序。那么,我怎样才能实现目标呢?

1 个答案:

答案 0 :(得分:2)

你在做什么都完全没问题。它将保留订单。您可以通过可视化来验证这一点。

我很快建立了这个用于显示存储在4d张量中的图像(其中dim=0是批量)或5d张量(其中dim=0是批处理而dim=1是序列):< / p>

def custom_imshow(tensor):
    if tensor.dim() == 4:
        count = 1
        for i in range(tensor.size(0)):
            img = tensor[i].numpy()
            plt.subplot(1, tensor.size(0), count)
            img = img / 2 + 0.5     # unnormalize
            img = np.transpose(img, (1, 2, 0))
            count += 1
            plt.imshow(img)
            plt.axis('off')

    if tensor.dim() == 5:
        count = 1
        for i in range(tensor.size(0)):
            for j in range(tensor.size(1)):
                img = tensor[i][j].numpy()
                plt.subplot(tensor.size(0), tensor.size(1), count)
                img = img / 2 + 0.5  # unnormalize
                img = np.transpose(img, (1, 2, 0))

                plt.imshow(img)
                plt.axis('off')
                count +=1

假设我们使用CIFAR-10数据集(由32x32x3尺寸的图像组成)。

对于张量x

>>> x.size()
torch.Size([4, 5, 3, 32, 32])
>>> custom_imshow(x)

enter image description here

执行x.view(-1, 3, 32, 32)后:

 # x.size() -> torch.Size([4, 5, 3, 32, 32])
 >>> x = x.view(-1, 3, 32, 32)
 >>> x.size()
 torch.Size([20, 3, 32, 32])
 >>> custom_imshow(x)

enter image description here

如果你回到5d张量视图:

# x.size() -> torch.Size([20, 3, 32, 32])
>>> x.view(4, 5, 3, 32, 32)
>>> x.size()
torch.Size([4, 5, 3, 32, 32])
>>> custom_imshow(x)

enter image description here