我有3个形状为2xN的numpy阵列(N大,几百万),称它们为a1,a2,a3。然后我有另一个形状Nx3的数组,其行值指的是数组a1,a2,a3之一,称之为排列。这个排列数组看起来像: [[0,1,2], [1,2,0] [1,0,2] ......最多N行]
我想创建另外3个形状为2xN的3个numpy数组b1,b2,b3,它们具有原始a1,a2,a3的内容,但是它们的列已根据置换数组的行进行置换。
我尝试过堆叠3个阵列的花式索引,以及numpy.choose,但是我无法让它工作。 我正在寻找没有python循环的解决方案。任何帮助将不胜感激!
修改
只是为了澄清我展示了我试图做的python循环实现:
aa = np.dstack((a1, a2, a3))
bb = np.empty_like(aa)
for i, o in enumerate(permutations):
bb[:,i, np.arange(3)] = aa[:, i, o]
然后我会从bb。
中检索b1,b2,b3答案 0 :(得分:1)
使用fancy-indexing
,您可以 -
bb = aa[:,np.arange(N),permutations.T]
请注意,这将是(2,3,N)
的形状。因此,要选择b1
,b2
,b3
,您可以这样做:
b1,b2,b3 = bb[:,0,:], bb[:,1,:], bb[:,2,:]
或者如果你坚持bb
与发布的代码形状相同,你可以添加:
bb = bb.swapaxes(1,2)
这是使用线性索引,切片以及当然NumPy broadcasting
-
idx = permutations + 3*np.arange(N)[:,None]
bb = aa.reshape(2,-1)[:,idx].reshape(2,N,3)
这将创建一个与发布的循环代码形状相同的bb
。
运行时测试
In [189]: def original_app(aa,permutations):
...: bb = np.empty_like(aa)
...: for i, o in enumerate(permutations):
...: bb[:,i, np.arange(3)] = aa[:, i, o]
...: return bb
...:
...:
...: def linear_index_app(aa,permutations):
...: idx = permutations + 3*np.arange(N)[:,None]
...: return aa.reshape(2,-1)[:,idx].reshape(2,N,3)
...:
In [190]: # Setup input arrays
...: N = 10000
...: a1 = np.random.rand(2,N)
...: a2 = np.random.rand(2,N)
...: a3 = np.random.rand(2,N)
...:
...: permutations = np.random.randint(0,3,(N,3))
...: aa = np.dstack((a1, a2, a3))
In [191]: %timeit original_app(aa,permutations)
10 loops, best of 3: 128 ms per loop
In [192]: %timeit aa[:,np.arange(N),permutations.T]
1000 loops, best of 3: 972 µs per loop
In [193]: %timeit linear_index_app(aa,permutations)
1000 loops, best of 3: 1.02 ms per loop
所以,似乎fancy-indexing
是最好的一个!