pytorch / numpy中的部分切片具有任意和可变数量的尺寸

时间:2019-06-10 23:42:00

标签: python numpy pytorch

给定numpy(或pytorch)中的二维张量,我可以一次沿所有维度进行部分切片,如下所示:

>>> import numpy as np
>>> a = np.arange(2*3).reshape(2,3)
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
>>> a[1:,1:]
array([[ 5,  6,  7],
       [ 9, 10, 11]])

如果我在实现时不知道维数,那么无论张量中的维数如何,如何实现相同的切片模式? (即,如果a[1:]仅具有一个维度,a具有两个维度,a[1:,1:]具有三个维度,依此类推,以此类推,依此类推)

如果我能像下面这样在一行代码中做到这一点将是很好的,但这是无效的:

a[1:,1:,1:]

我对一种适用于pytorch张量的解决方案特别感兴趣(只是将火炬替换为上面的numpy,并且示例相同),但是我认为,如果该解决方案同时适用于numpy和pytorch,则可能是最好的。

1 个答案:

答案 0 :(得分:1)

答案: 制作slice个对象的元组可达到以下目的:

a[(slice(1,None),) * len(a.shape)]

说明: slice是一个内置的python类(不与numpy或pytorch绑定),为描述切片提供了下标符号的替代方法。 The answera different question建议使用此方法将切片信息存储在python变量中。 python glossary指出

  

方括号(下标)表示法在内部使用slice个对象。

由于numpy ndarrayspytorch tensors__getitem__方法支持切片的多维索引,因此它们还必须支持切片对象的多维索引,因此我们可以对这些对象进行元组化切成正确的长度。

顺便说一句,您可以通过如下创建一个虚拟类,然后对它进行切片来了解python如何使用切片对象:

class A(object):
    def __getitem__(self, ix):
        return ix

print(A()[5])  # 5
print(A()[1:])  # slice(1, None, None)
print(A()[1:,1:])  # (slice(1, None, None), slice(1, None, None))
print(A()[1:,slice(1,None)])  #  (slice(1, None, None), slice(1, None, None))