我发现自己在几个不同的场景中一直面临着这个问题。所以我想在这里分享一下,看看是否有最佳解决方法。
假设我有一个大数组的X和另一个X大小相同的数组y,其上有x所属的标签。如下所示。
X = np.array(['obect1', 'object2', 'object3', 'object4', 'object5'])
y = np.array([0, 1, 1, 0, 2])
我想要的是构建一个使用标签作为键的字典/哈希以及所有对象的索引以及X中的那些标签作为项目 。所以在这种情况下,所需的输出将是:
{0: (array([0, 3]),), 1: (array([1, 2]),), 2: (array([4]),)}
请注意,实际上X上的内容并不重要,但为了完整起见,我将其包括在内。
现在,我对该问题的天真解决方案是遍历所有标签并使用np.where==label
来构建字典。更详细地说,我使用这个函数:
def get_key_to_indexes_dic(labels):
"""
Builds a dictionary whose keys are the labels and whose
items are all the indexes that have that particular key
"""
# Get the unique labels and initialize the dictionary
label_set = set(labels)
key_to_indexes = {}
for label in label_set:
key_to_indexes[label] = np.where(labels==label)
return key_to_indexes
所以现在我的问题的核心: 有没有办法做得更好?有没有一种自然的方法来解决这个使用numpy函数?是不是以某种方式误入歧途?
作为一个不太重要的横向问题:上述定义中解决方案的复杂性是多少?我相信解决方案的复杂性如下:
或者用文字表示标签的数量乘以在一组y大小中使用np.where
的复杂性加上从一个数组中创建一个集合的复杂性。这是对的吗?
Pd积。我找不到这个具体问题的相关帖子,如果你有改变标题的建议或任何我会感激的。
答案 0 :(得分:2)
如果您使用字典存储索引,则只需遍历一次:
from collections import defaultdict
def get_key_to_indexes_ddict(labels):
indexes = defaultdict(list)
for index, label in enumerate(labels):
indexes[label].append(index)
缩放看起来很像你已经分析了你的选项,因为它上面的函数是O(N),其中N是y
的大小,因为检查一个值是否在字典中是O (1)。
有趣的是,由于np.where
遍历的速度要快得多,只要标签数量很少,你的功能就会更快。当有许多不同的标签时,我看起来更快。
以下是功能扩展的方式:
蓝线是你的功能,红线是我的。线条样式表示不同标签的数量。 {10: ':', 100: '--', 1000: '-.', 10000: '-'}
。您可以看到我的功能相对独立于标签数量,而当您有许多标签时,您的功能会很快变慢。如果您的标签很少,那么与您的标签相比就更好了。
答案 1 :(得分:0)
我也在努力寻找解决这类问题的“numpythonic”方法。这是我提出的最好的方法,虽然需要更多的内存:
def get_key_to_indexes_dict(labels):
indices = numpy.argsort(labels)
bins = numpy.bincount(labels)
indices = numpy.split(indices, numpy.cumsum(bins[bins > 0][:-1]))
return dict(zip(numpy.unique(labels), indices))
答案 2 :(得分:0)
numpy_indexed包(免责声明:我是它的作者)可用于以完全向量化的方式解决此类问题,并且具有O(nlogn)最坏情况时间复杂度:
import numpy_indexed as npi
indices = np.arange(len(labels))
unique_labels, indices_per_label = npi.group_by(labels, indices)
请注意,对于此类功能的许多常见应用程序,例如计算组标签的总和或均值,更有效的是不计算索引的拆分列表,而是在npi中使用其功能;即,npi.group_by(labels).mean(some_corresponding_array),而不是循环遍历indices_per_label并取这些指数的均值。
答案 3 :(得分:0)
假设标签是连续整数[0, m]
并取n = len(labels)
,set(labels)
的复杂度为O(n),循环中np.where
的复杂度为O (m * n个)。但是,总体复杂度写为O(m * n)而不是O(m * n + n),请参阅"Big O notation" on wikipedia。
你可以做两件事来提高性能:1)使用更有效的算法(更低的复杂性)和2)用快速数组操作替换Python循环。
目前发布的其他答案正是这样做的,而且代码非常合理。然而,最佳解决方案将是完全矢量化的并且具有O(n)复杂度。这可以使用Scipy的某个较低级别函数来完成:
def sparse_hack(labels):
from scipy.sparse._sparsetools import coo_tocsr
labels = labels.ravel()
n = len(labels)
nlabels = np.max(labels) + 1
indices = np.arange(n)
sorted_indices = np.empty(n, int)
offsets = np.zeros(nlabels+1, int)
dummy = np.zeros(n, int)
coo_tocsr(nlabels, 1, n, labels, dummy, indices,
offsets, dummy, sorted_indices)
return sorted_indices, offsets
可以找到coo_tocsr
的来源here。我使用它的方式,它基本上执行间接counting sort。说实话,这是一个相当模糊的方法,我建议你在其他答案中使用其中一种方法。