将多维pytorch张量拆分为“ n”个较小的张量

时间:2020-05-15 03:29:56

标签: python pytorch tensor

假设我有一个5D张量,其形状如下:(1、3、10、40、1)。我想根据某个维度将其分为较小的相等张量(如果可能),并保留等于其他维度的 step 1 )。

例如,我想根据第四维(= 40 )对其进行拆分,其中每个张量的大小将等于 10 。因此,第一个 tensor_1 的值将为 0-> 9 tensor_2 的值将为 1-> 10 等等。

39个张量将具有以下形状:

Shape of tensor_1 : (1, 3, 10, 10, 1)
Shape of tensor_2 : (1, 3, 10, 10, 1)
Shape of tensor_3 : (1, 3, 10, 10, 1)
...    
Shape of tensor_39 : (1, 3, 10, 10, 1)

这是我尝试过的:

a = torch.randn(1, 3, 10, 40, 1)

chunk_dim = 10
a_split = torch.chunk(a, chunk_dim, dim=3)

这给了我4个张量。我该如何编辑它,以便像我解释的那样有39个张量,步长= 1?

2 个答案:

答案 0 :(得分:1)

您可以使用以下方法访问第i个拆分:

a[:,:,:,i:i+10,:]

例如,您示例中的tensor_3可以通过以下方式访问:

a[:,:,:,3:13,:]

如果需要创建这些拆分的副本列表,则可以运行循环并使用迭代器进行索引。

答案 1 :(得分:0)

这会创建我想要的重叠张量:

torch.unfold(dimension, size, step)