在多维数组上使用numpy.take?

时间:2012-08-08 16:21:44

标签: python numpy scipy

是否可以像花式索引一样使用多个轴?

多维数组相当大,所以我希望有可能获得加速。

例如:

import numpy as np
x = np.random.rand(20,20,20,20)
m = np.where(x>0.5)
m = (m[0],m[1],m[2])
print x[m].shape

1 个答案:

答案 0 :(得分:3)

您的代码:

m = np.where(x>0.5)
m = (m[0],m[1],m[2])
result = x[m]

可以通过使用repeat来编写以避免使用np.where:

m = np.sum(x>0.5,-1)
result = x.reshape(-1,x.shape[-1]).repeat(w.ravel(), 0)

这似乎快了4倍。不过我想知道你是不是要求

m = np.any(x>0.5,-1)
result = x[m,:]

这不会产生重复(虽然这里仍然需要重塑)?