我试图将二阶张量转换成二进制三阶张量。给定二阶张量作为amxn numpy数组:A,我需要取A中的每个元素值:x,并用向量:v替换它,其中维度等于A的最大值,但值为1递增在v的索引处对应于值x(即v [x] = 1)。我一直在关注这个问题:Increment given indices in a matrix,它解决了在2维坐标给出的索引处产生增量的数组。我一直在阅读答案,并尝试使用np.ravel_multi_index()和np.bincount()来做同样的事情,但有3维坐标,但我继续得到一个ValueError:“坐标数组中的无效条目”。这就是我一直在使用的:
def expand_to_tensor_3(array):
(x, y) = array.shape
(a, b) = np.indices((x, y))
a = a.reshape(x*y)
b = b.reshape(x*y)
tensor_3 = np.bincount(np.ravel_multi_index((a, b, array.reshape(x*y)), (x, y, np.amax(array))))
return tensor_3
如果您知道这里有什么问题或者知道更好的方法来实现我的目标,那么两者都会非常有用,谢谢。
答案 0 :(得分:3)
您可以使用(A[:,:,np.newaxis] == np.arange(A.max()+1)).astype(int)
。
以下是演示:
In [52]: A
Out[52]:
array([[2, 0, 0, 2],
[3, 1, 2, 3],
[3, 2, 1, 0]])
In [53]: B = (A[:,:,np.newaxis] == np.arange(A.max()+1)).astype(int)
In [54]: B
Out[54]:
array([[[0, 0, 1, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0]],
[[0, 0, 0, 1],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]],
[[0, 0, 0, 1],
[0, 0, 1, 0],
[0, 1, 0, 0],
[1, 0, 0, 0]]])
检查A
的几个元素:
In [55]: A[0,0]
Out[55]: 2
In [56]: B[0,0,:]
Out[56]: array([0, 0, 1, 0])
In [57]: A[1,3]
Out[57]: 3
In [58]: B[1,3,:]
Out[58]: array([0, 0, 0, 1])
表达式A[:,:,np.newaxis] == np.arange(A.max()+1)
使用broadcasting将A
的每个元素与np.arange(A.max()+1)
进行比较。对于单个值,这看起来像:
In [63]: 3 == np.arange(A.max()+1)
Out[63]: array([False, False, False, True], dtype=bool)
In [64]: (3 == np.arange(A.max()+1)).astype(int)
Out[64]: array([0, 0, 0, 1])
A[:,:,np.newaxis]
是A
的三维视图,其形状为(3,4,1)
。添加了额外维度,以便与np.arange(A.max()+1)
的比较广播到每个元素,从而得到形状为(3, 4, A.max()+1)
的结果。
通过一个微不足道的变化,这将适用于n维数组。使用省略号...
索引numpy数组意味着“所有其他维度”。所以
(A[..., np.newaxis] == np.arange(A.max()+1)).astype(int)
将n维数组转换为(n + 1)维数组,其中最后一个维度是A
中整数的二进制指示符。这是一个具有一维数组的例子:
In [6]: a = np.array([3, 4, 0, 1])
In [7]: (a[...,np.newaxis] == np.arange(a.max()+1)).astype(int)
Out[7]:
array([[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0]])
答案 1 :(得分:0)
你可以这样做:
tensor_3 = np.bincount(np.ravel_multi_index((a, b, array.reshape(x*y)),
(x, y, np.amax(array) + 1)))
不同之处在于我在amax()
结果中加1,因为ravel_multi_index()
期望索引都严格小于维度,而不是小于或等于。
如果这是你想要的,我不能100%肯定;另一种使代码运行的方法是在mode='clip'
中指定mode='wrap'
或ravel_multi_index()
,这会做一些不同的事情,我猜测不太正确。但你可以尝试一下。