我有一个非常大的2D numpy数组(~5e8值)。我使用scipy.ndimage.label
标记了该数组然后我想找到包含每个标签的扁平数组的随机索引。我可以这样做:
import numpy as np
from scipy.ndimage import label
base_array = np.random.randint(0, 5, (100000, 5000))
labeled_array, nlabels = label(base_array)
for label_num in xrange(1, nlabels+1):
indices = np.where(labeled_array.flat == label_num)[0]
index = np.random.choice(indices)
但是,这个数组的数据很慢。我还尝试将np.where
替换为:
indices = np.argwhere(labeled_array.flat == label).squeeze()
发现它变慢了。我怀疑布尔掩码是缓慢的部分。无论如何要加快速度,或者更好的方法来做到这一点。我将在我的实际应用程序中说,数组相当稀疏,填充量约为25%,但我没有使用scipy的稀疏数组函数。
答案 0 :(得分:1)
您怀疑为每个标签单独屏蔽是否昂贵是正确的,因为无论您如何操作,屏蔽将始终为O(n)。
我们可以通过标签进行调整,然后从每个相同标签的块中随机挑选来规避这一点。
由于标签是整数范围,我们可以通过使用scipy中提供的一些稀疏矩阵机制来使argsort比np.argsort
便宜。
由于我的机器没有大量的ram,我不得不缩小你的例子(因子4)。然后它会在大约5秒内运行。
import numpy as np
from scipy.ndimage import label
from scipy import sparse
def multi_randint(bins):
"""draw one random int from each range(bins[i], bins[i+1])"""
high = np.diff(bins)
n = high.size
pick = np.random.randint(0, 1<<30, (n,))
reject = np.flatnonzero(pick + (1<<30) % high >= (1<<30))
while reject.size:
npick = np.random.randint(0, 1<<30, (reject.size,))
rejrej = npick + (1<<30) % sizes[reject] >= (1<<30)
pick[reject] = npick
reject = reject[rejrej]
return bins[:-1] + pick % high
# build mock data, note that I had to shrink by 4x b/c memory
base_array = np.random.randint(0, 5, (50000, 2500), dtype=np.int8)
labeled_array, nlabels = label(base_array)
# build auxiliary sparse matrix
h = sparse.csr_matrix(
(np.ones(labeled_array.size, bool), labeled_array.ravel(),
np.arange(labeled_array.size+1, dtype=np.int32)),
(labeled_array.size, nlabels+1))
# conversion to csc argsorts the labels (but cheaper than argsort)
h = h.tocsc()
# draw
result = h.indices[multi_randint(h.indptr)]
# check result
assert len(set(labeled_array.ravel()[result])) == nlabels+1