更快的替代numpy.where?

时间:2015-10-22 13:15:19

标签: python numpy

我有一个3d数组,里面填充从0到N的整数。我需要一个与数组相等的索引列表1,2,3,... N.我可以用np.where作为如下:

N = 300
shape = (1000,1000,10)
data = np.random.randint(0,N+1,shape)
indx = [np.where(data == i_id) for i_id in range(1,data.max()+1)]

但这很慢。根据这个问题 fast python numpy where functionality? 应该可以加快索引搜索的速度,但是我无法将那里提出的方法转移到我获取实际索引的问题上。什么是加速上述代码的最佳方法?

作为一个附加组件:我想稍后存储索引,为此有意义的是使用np.ravel_multi_index来减小从保存3个索引到仅1的大小,即使用:

indx = [np.ravel_multi_index(np.where(data == i_id), data.shape) for i_id in range(1, data.max()+1)]

更接近于Matlab的查找功能。这可以直接包含在不使用np.where的解决方案中吗?

4 个答案:

答案 0 :(得分:9)

我认为这个问题的标准向量化方法最终会占用大量内存 - 对于int64数据,它需要O(8 * N * data.size)字节,或者〜22 gig的内存,例如你上面给出了。我认为这不是一种选择。

您可以通过使用稀疏矩阵来存储唯一值的位置来取得一些进展。例如:

import numpy as np
from scipy.sparse import csr_matrix

def compute_M(data):
    cols = np.arange(data.size)
    return csr_matrix((cols, (data.ravel(), cols)),
                      shape=(data.max() + 1, data.size))

def get_indices_sparse(data):
    M = compute_M(data)
    return [np.unravel_index(row.data, data.shape) for row in M]

利用稀疏矩阵构造函数中的快速代码以有用的方式组织数据,构造稀疏矩阵,其中行i仅包含展平数据等于i的索引。 / p>

为了测试它,我还将定义一个能够直接执行方法的函数:

def get_indices_simple(data):
    return [np.where(data == i) for i in range(0, data.max() + 1)]

这两个函数为同一输入提供相同的结果:

data_small = np.random.randint(0, 100, size=(100, 100, 10))
all(np.allclose(i1, i2)
    for i1, i2 in zip(get_indices_simple(data_small),
                      get_indices_sparse(data_small)))
# True

稀疏方法比数据集的简单方法快一个数量级:

data = np.random.randint(0, 301, size=(1000, 1000, 10))

%time ind = get_indices_simple(data)
# CPU times: user 14.1 s, sys: 638 ms, total: 14.7 s
# Wall time: 14.8 s

%time ind = get_indices_sparse(data)
# CPU times: user 881 ms, sys: 301 ms, total: 1.18 s
# Wall time: 1.18 s

%time M = compute_M(data)
# CPU times: user 216 ms, sys: 148 ms, total: 365 ms
# Wall time: 363 ms

稀疏方法的另一个好处是矩阵M最终是一种非常紧凑和有效的方式来存储所有相关信息供以后使用,如问题的附加部分所述。希望这很有用!

编辑:我意识到初始版本中存在一个错误:如果该范围内的任何值没有出现在数据中,则会失败:现在已修复此错误。

答案 1 :(得分:7)

我正在考虑这一点,并意识到使用Pandas groupby()解决这个问题的方法更为直观(但速度稍慢)。考虑一下:

import numpy as np
import pandas as pd

def get_indices_pandas(data):
    d = data.ravel()
    f = lambda x: np.unravel_index(x.index, data.shape)
    return pd.Series(d).groupby(d).apply(f)

这会从我之前的回答中返回与get_indices_simple相同的结果:

data_small = np.random.randint(0, 100, size=(100, 100, 10))
all(np.allclose(i1, i2)
    for i1, i2 in zip(get_indices_simple(data_small),
                      get_indices_pandas(data_small)))
# True

这种Pandas方法比不太直观的矩阵方法略慢:

data = np.random.randint(0, 301, size=(1000, 1000, 10))

%time ind = get_indices_simple(data)
# CPU times: user 14.2 s, sys: 665 ms, total: 14.8 s
# Wall time: 14.9 s

%time ind = get_indices_sparse(data)
# CPU times: user 842 ms, sys: 277 ms, total: 1.12 s
# Wall time: 1.12 s

%time ind = get_indices_pandas(data)
# CPU times: user 1.16 s, sys: 326 ms, total: 1.49 s
# Wall time: 1.49 s

答案 2 :(得分:4)

这是一种矢量化方法 -

# Mask of matches for data elements against all IDs from 1 to data.max()
mask = data == np.arange(1,data.max()+1)[:,None,None,None]

# Indices of matches across all IDs and their linear indices
idx = np.argwhere(mask.reshape(N,-1))

# Get cut indices where IDs shift
_,cut_idx = np.unique(idx[:,0],return_index=True)

# Cut at shifts to give us the final indx output
out = np.hsplit(idx[:,1],cut_idx[1:])

答案 3 :(得分:2)

基本上,对其他问题的大多数答案都有消息"使用间接排序"。

我们可以通过在展平数组上调用find来获得与i = [0..N]对应的线性索引(与MATLAB中的numpy.argsort类似):

flat = data.ravel()
lin_idx = np.argsort(flat, kind='mergesort')

然后我们得到一个大阵列;哪些指数属于i?我们只是根据每个i的计数来分割indices数组:

ans = np.split(lin_idx, np.cumsum(np.bincount(flat)[:-1]))

如果您仍然需要某处的3D索引,可以使用numpy.unravel_index