我需要在numpy数组“label”中找到最常见的元素,只要这些元素位于mask数组中。这是蛮力方法:
def getlabel(mask, label):
# get majority label
assert label.shape == mask.shape
tmp = []
for i in range(mask.shape[0]):
for j in range(mask.shape[1]):
if mask[i][j] == True:
tmp.append(label[i][j])
return Counter(tmp).most_common(1)[0][0]
但我不认为这是最优雅,最快的方法。我应该使用哪些其他数据结构? (哈希,字典等......)?
答案 0 :(得分:1)
假设您的mask
是一个布尔数组:
import numpy as np
cnt = np.bincount(label[mask].flat)
这将为您提供值为0,1,2,... max(标签)的出现次数的向量
你可以找到最频繁的
most_frequent = np.argmax(cnt)
当然,输入数据中这些元素的数量是
cnt[most_frequent]
通常,np.bincount
很快。让我们尝试使用最大数量为999(即1000个分档)的标签和一个由8 000 000个值掩盖的10 000 000个元素阵列:
data = np.random.randint(0, 1000, (1000, 10000))
mask = np.random.random((1000, 10000)) < 0.8
# time this section
cnt = np.bincount(data[mask].flat)
使用我的机器需要80毫秒。 argmax
可能需要2 ns / bin,所以即使你的标签整数有点分散,也没关系。
如果满足以下条件,这种方法可能是最快的方法:
此解决方案可能适用于其他一些案例,但更多的问题是如何以及是否有更好的解决方案。 (请参阅metaperture
的答案。)例如,将Python列表简单转换为ndarray
代价相当昂贵,bincount
获得的速度优势将会丢失输入是一个Python列表,数据量不大。
整数空间中标签的稀疏性不是本身的问题。创建和归零输出向量相对较快,使用np.nonzero
进行压缩很容易,也很快。但是,如果最大标签值与输入数组的大小相比较大,那么速度优势可能会丢失。
答案 1 :(得分:1)
np.bincount
不一般方法。np.bincount
对于有界,低熵,离散分布更快。但是,它会失败:
对于一般解决方案,您应该执行以下操作之一:
cnt = Counter((l for m, l in zip(mask.flat, label.flat) if m)) # or...
cnt = Counter(label[mask].flat)
或者:
scipy.stats.mode(label[mask].flat)
在我的测试中,前者快了约20倍。如果你知道分布是离散的,具有相对较低的界限和熵,那么bincount会更快。
如果上述速度不够快,那么比bincount更好的通用方法是对数据进行采样
collections.Counter(np.random.choice(data[mask], 1000)).most_common(1)
scipy.stats.mode(np.random.choice(data[mask], 1000))
上述两种方法都比非采样版本快一个数量级,并且即使是最大的病态分布也能快速收敛到模式。