def _partition_by_context(self, labels, contexts):
# partition the labels by context
assert len(labels) == len(contexts)
by_context = collections.defaultdict(list)
for i, label in enumerate(labels):
by_context[contexts[i]].append(label)
# now remove any that don't have enough samples
keys_to_remove = []
for key, value in by_context.iteritems():
if len(value) < self._min_samples_context:
keys_to_remove.append(key)
for key in keys_to_remove:
del by_context[key]
return by_context
(unicode, int)
:示例(u'ffcd6881167b47d492adf3f542af94c6', 2)
。上下文值经常重复。例如,上下文列表中可能有10000个值,但只有100个不同的值。len(labels) == len(contexts)
为真,如第一行所述labels[i]
和contexts[i]
“齐心协力”此功能的作用是按上下文值对标签中的值进行分区。然后在最后,如果标签计数太低,则删除字典条目。
因此,如果所有上下文值都相同,则返回值将是具有单个条目的字典,key = context,value =所有标签的列表。
如果有N个不同的上下文值,则返回值将具有N个键(每个上下文一个),并且每个键的值将是与特定上下文相关联的标签列表。列表中的标签排序并不重要。
使用不同的args调用此函数数百万次。我已经确定它是使用gprof2dot的瓶颈。大部分成本都在第一个for循环中的list append()调用中。
谢谢!
答案 0 :(得分:1)
尝试替换
for i, label in enumerate(labels):
by_context[contexts[i]].append(label)
与
for context, label in zip(contexts, labels):
by_context[context].append(label)
而不是使用keys_to_remove
,请尝试
n = self._min_samples_context
return {c:ls for c,ls in by_context.items() if len(ls) >= n}
答案 1 :(得分:0)
看起来像这两个数组的东西会成为一个很好的测试用例:
N = 100
labels=np.arange(N)
contexts=np.random.randint(0,len(labels)/10,len(labels))
通过这些阵列,@ Hugh的改进速度提高了约10%。
我对其他问题的经验表明,defaultdict
是收集这样的值的一种非常好的方式。唯一可能更快的是将其转换为某种numpy
索引问题。