给定numpy(或pytorch)中的二维张量,我可以一次沿所有维度进行部分切片,如下所示:
>>> import numpy as np
>>> a = np.arange(2*3).reshape(2,3)
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> a[1:,1:]
array([[ 5, 6, 7],
[ 9, 10, 11]])
如果我在实现时不知道维数,那么无论张量中的维数如何,如何实现相同的切片模式? (即,如果a[1:]
仅具有一个维度,a
具有两个维度,a[1:,1:]
具有三个维度,依此类推,以此类推,依此类推)
如果我能像下面这样在一行代码中做到这一点将是很好的,但这是无效的:
a[1:,1:,1:]
我对一种适用于pytorch张量的解决方案特别感兴趣(只是将火炬替换为上面的numpy,并且示例相同),但是我认为,如果该解决方案同时适用于numpy和pytorch,则可能是最好的。
答案 0 :(得分:1)
答案: 制作slice个对象的元组可达到以下目的:
a[(slice(1,None),) * len(a.shape)]
说明:
slice
是一个内置的python类(不与numpy或pytorch绑定),为描述切片提供了下标符号的替代方法。 The answer至a different question建议使用此方法将切片信息存储在python变量中。 python glossary指出
方括号(下标)表示法在内部使用slice个对象。
由于numpy ndarrays和pytorch tensors的__getitem__
方法支持切片的多维索引,因此它们还必须支持切片对象的多维索引,因此我们可以对这些对象进行元组化切成正确的长度。
顺便说一句,您可以通过如下创建一个虚拟类,然后对它进行切片来了解python如何使用切片对象:
class A(object):
def __getitem__(self, ix):
return ix
print(A()[5]) # 5
print(A()[1:]) # slice(1, None, None)
print(A()[1:,1:]) # (slice(1, None, None), slice(1, None, None))
print(A()[1:,slice(1,None)]) # (slice(1, None, None), slice(1, None, None))