找到掩码数组中最常见的元素

时间:2014-07-17 22:12:40

标签: python algorithm numpy data-structures

我需要在numpy数组“label”中找到最常见的元素,只要这些元素位于mask数组中。这是蛮力方法:

def getlabel(mask, label):
    # get majority label
    assert label.shape == mask.shape

    tmp = []
    for i in range(mask.shape[0]):
        for j in range(mask.shape[1]):
            if mask[i][j] == True:
                tmp.append(label[i][j])
    return Counter(tmp).most_common(1)[0][0]

但我不认为这是最优雅,最快的方法。我应该使用哪些其他数据结构? (哈希,字典等......)?

2 个答案:

答案 0 :(得分:1)

假设您的mask是一个布尔数组:

import numpy as np

cnt = np.bincount(label[mask].flat)

这将为您提供值为0,1,2,... max(标签)的出现次数的向量

你可以找到最频繁的

most_frequent = np.argmax(cnt)

当然,输入数据中这些元素的数量是

cnt[most_frequent]

通常,np.bincount很快。让我们尝试使用最大数量为999(即1000个分档)的标签和一个由8 000 000个值掩盖的10 000 000个元素阵列:

data = np.random.randint(0, 1000, (1000, 10000))
mask = np.random.random((1000, 10000)) < 0.8

# time this section
cnt = np.bincount(data[mask].flat)

使用我的机器需要80毫秒。 argmax可能需要2 ns / bin,所以即使你的标签整数有点分散,也没关系。

如果满足以下条件,这种方法可能是最快的方法:

  • 标签是0..N范围内的整数,其中N不大于输入数组的大小
  • 输入数据位于NumPy数组

此解决方案可能适用于其他一些案例,但更多的问题是如何以及是否有更好的解决方案。 (请参阅metaperture的答案。)例如,将Python列表简单转换为ndarray代价相当昂贵,bincount获得的速度优势将会丢失输入是一个Python列表,数据量不大。

整数空间中标签的稀疏性不是本身的问题。创建和归零输出向量相对较快,使用np.nonzero进行压缩很容易,也很快。但是,如果最大标签值与输入数组的大小相比较大,那么速度优势可能会丢失。

答案 1 :(得分:1)

np.bincount 一般方法。np.bincount对于有界,低熵,离散分布更快。但是,它会失败:

  • 如果分布是无界的,则使用的内存是无界的(对于任意小的输入数组,可以任意大)
  • 如果分布是连续的,则bincount的argmax不是模式(从技术上来说它是KDE的MAP,其中KDE是使用类似直方图的方法生成的)
  • 如果分布具有高熵/扩散,则np.bincount的基于bin的表示没有意义(赢了但不会失败)

对于一般解决方案,您应该执行以下操作之一:

cnt = Counter((l for m, l in zip(mask.flat, label.flat) if m)) # or...
cnt = Counter(label[mask].flat)

或者:

scipy.stats.mode(label[mask].flat)

在我的测试中,前者快了约20倍。如果你知道分布是离散的,具有相对较低的界限和熵,那么bincount会更快。

如果上述速度不够快,那么比bincount更好的通用方法是对数据进行采样

collections.Counter(np.random.choice(data[mask], 1000)).most_common(1)
scipy.stats.mode(np.random.choice(data[mask], 1000))

上述两种方法都比非采样版本快一个数量级,并且即使是最大的病态分布也能快速收敛到模式。