加速Numpy数组切片

时间:2014-03-13 13:38:39

标签: python numpy

我需要沿着两个维度切割一个中等大小的2d Numpy数组。例如,

import numpy as np
X = np.random.normal(loc=0, scale=1, size=(3000, 100))

从这个数组中,我需要选择大量的行和相当少的列,例如

row_idx = np.random.random_integers(0, 2999, 2500)
col_idx = np.random.random_integers(0, 99, 10)

现在,我通过以下命令执行此操作:

X.take(col_idx, axis=1).take(row_idx, axis=0)

我的电脑大约需要115μs。问题是我每次运行需要做几百万次这个步骤。

你认为有机会加快这个速度吗?

修改(其他信息): 我有一个矩阵X,它是nxk。 n行包含1xk向量。 有三组:有效组(V),左组(L)和右组(R)。此外,还有系数v0和v。我需要计算这个数量:http://goo.gl/KNoSl3(对不起,我不能发布图像)。问题中的公式选择左(右)集中的所有X行以及活动集中的所有列。

修改2

我发现了另一个小改进。

X.take(col_idx, axis=1, mode='clip').take(row_idx, axis=0, mode='clip')

快一点(我的机器上大约25%)。

2 个答案:

答案 0 :(得分:1)

让我们做一些我们制作满足我们的n维网格条件的一维索引数组。

def make_multi_index(arr, *inds):
    tmp = np.meshgrid(*inds, indexing='ij')
    idx = np.vstack([x.ravel() for x in tmp])
    return np.ravel_multi_index(idx, X.shape)

使用您的测试数组和原始案例作为参考:

%timeit X.take(col_idx, axis=1).take(row_idx, axis=0)
10000 loops, best of 3: 95.4 µs per loop

让我们使用这个函数来构建索引,保存它们,然后使用take来返回你想要的输出:

inds = make_multi_index(X, row_idx, col_idx)
tmp = np.take(X,inds).reshape(row_idx.shape[0], col_idx.shape[0])

np.allclose(tmp, X.take(col_idx, axis=1).take(row_idx, axis=0))
Out[128]: True

因此,建立我们的指数并保持它们似乎有效,现在是时间:

%timeit make_multi_index(X, row_idx, col_idx)
1000 loops, best of 3: 356 µs per loop

%timeit np.take(X,inds).reshape(row_idx.shape[0], col_idx.shape[0])
10000 loops, best of 3: 59.9 µs per loop

因为它发生的并不是非常好 - 这可能会随着你想要的更多维度而变得更好。无论如何,如果你保留这些索引超过10-15次迭代,它可以帮助一些或者如果你添加一个额外的维度并同时获取所有非活动数据集。

答案 1 :(得分:0)

您可以使用二维花式索引:

X[row_idx,col_idx[:,None]]

但是我的机器需要大约1ms,而使用你的方法大约需要300us。

除非您在row_idxcol_idx中提供有关值的其他信息,否则您的方法似乎是最好的。