从张量数组切片不均匀的列

时间:2019-05-08 06:45:04

标签: python python-3.x pytorch tensor tensor-indexing

我有一个像这样的数组:

([[[ 0,  1,  2],
 [ 3,  4,  5]],

[[ 6,  7,  8],
[ 9, 10, 11]],

[[12, 13, 14],
[15, 16, 17]]])

如果我想将数字12切为17,我会使用:

arr[2, 0:2, 0:3]

但是我将如何将数组切成12到16?

2 个答案:

答案 0 :(得分:2)

您首先需要“展平”最后两个维度。只有这样,您才能提取所需的元素:

xf = x.view(x.size(0), -1)  # flatten the last dimensions
xf[2, 0:5]
Out[87]: tensor([12, 13, 14, 15, 16])

答案 1 :(得分:0)

另一种方法是简单地将张量索引并切成所需的内容,如:

# input tensor 
t = tensor([[[ 0,  1,  2],
             [ 3,  4,  5]],

           [[ 6,  7,  8],
            [ 9, 10, 11]],

           [[12, 13, 14],
            [15, 16, 17]]])

# slice the last `block`, then flatten it and 
# finally slice all elements but the last one
In [10]: t[-1].view(-1)[:-1]   
Out[10]: tensor([12, 13, 14, 15, 16])

请注意,由于这是基本切片,因此会返回 视图 。因此,对切片部分进行任何更改也会影响原始张量。例如:

# assign it to some variable name
In [11]: sliced = t[-1].view(-1)[:-1] 
In [12]: sliced      
Out[12]: tensor([12, 13, 14, 15, 16])

# modify one element
In [13]: sliced[-1] = 23   

In [14]: sliced  
Out[14]: tensor([12, 13, 14, 15, 23])

# now, the original tensor is also updated
In [15]: t  
Out[15]: 
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]],

        [[12, 13, 14],
         [15, 23, 17]]])