我想“归零” n维数组中所有具有两个或多个匹配索引的位置的元素。在二维中,这实际上是np.fill_diagonal()
,但在考虑第三维时,该功能就不足。
下面是我想要做的暴力版本。有什么方法可以清理并使其在n个维度上起作用?
x = np.ones([3,3,3])
x[:,0,0] = 0
x[0,:,0] = 0
x[0,0,:] = 0
x[:,1,1] = 0
x[1,:,1] = 0
x[1,1,:] = 0
x[:,2,2] = 0
x[2,:,2] = 0
x[2,2,:] = 0
print(x)
答案 0 :(得分:2)
一种方法是np.einsum
:
>>> a = np.ones((4,4,4), int)
>>> for n in range(3):
... np.einsum(f"{'iijii'[n:n+3]}->ij", a)[...] = 0
...
>>> a
array([[[0, 0, 0, 0],
[0, 0, 1, 1],
[0, 1, 0, 1],
[0, 1, 1, 0]],
[[0, 0, 1, 1],
[0, 0, 0, 0],
[1, 0, 0, 1],
[1, 0, 1, 0]],
[[0, 1, 0, 1],
[1, 0, 0, 1],
[0, 0, 0, 0],
[1, 1, 0, 0]],
[[0, 1, 1, 0],
[1, 0, 1, 0],
[1, 1, 0, 0],
[0, 0, 0, 0]]])
一般(ND)情况:
>>> from string import ascii_lowercase
>>> from itertools import combinations
>>>
>>> a = np.ones((4,4,4,4), int)
>>> n = a.ndim
>>> ltrs = ascii_lowercase[:n-2]
>>> for I in combinations(range(n), 2):
... li = iter(ltrs)
... np.einsum(''.join('z' if k in I else next(li) for k in range(n)) + '->z' + ltrs, a)[...] = 0
...
>>> a
array([[[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 1],
[0, 0, 1, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 1],
[0, 0, 0, 0],
[0, 1, 0, 0]],
<snip>