提高这个python函数的性能

时间:2015-01-21 19:15:03

标签: python performance numpy

def _partition_by_context(self, labels, contexts):
    # partition the labels by context
    assert len(labels) == len(contexts)
    by_context = collections.defaultdict(list)
    for i, label in enumerate(labels):
        by_context[contexts[i]].append(label)

    # now remove any that don't have enough samples
    keys_to_remove = []
    for key, value in by_context.iteritems():
        if len(value) < self._min_samples_context:
            keys_to_remove.append(key)
    for key in keys_to_remove:
        del by_context[key]

    return by_context
  1. 标签是浮动数组。
  2. contexts是一个元组的python列表。每个元组的格式为(unicode, int):示例(u'ffcd6881167b47d492adf3f542af94c6', 2)。上下文值经常重复。例如,上下文列表中可能有10000个值,但只有100个不同的值。
  3. len(labels) == len(contexts)为真,如第一行所述
  4. 索引i处的
  5. 标签与索引i处的上下文相关联。也就是说,labels[i]contexts[i]“齐心协力”
  6. 此功能的作用是按上下文值对标签中的值进行分区。然后在最后,如果标签计数太低,则删除字典条目。

    因此,如果所有上下文值都相同,则返回值将是具有单个条目的字典,key = context,value =所有标签的列表。

    如果有N个不同的上下文值,则返回值将具有N个键(每个上下文一个),并且每个键的值将是与特定上下文相关联的标签列表。列表中的标签排序并不重要。

    使用不同的args调用此函数数百万次。我已经确定它是使用gprof2dot的瓶颈。大部分成本都在第一个for循环中的list append()调用中。

    谢谢!

2 个答案:

答案 0 :(得分:1)

尝试替换

    for i, label in enumerate(labels):
        by_context[contexts[i]].append(label)

for context, label in zip(contexts, labels):
    by_context[context].append(label)

而不是使用keys_to_remove,请尝试

n = self._min_samples_context
return {c:ls for c,ls in by_context.items() if len(ls) >= n}

答案 1 :(得分:0)

看起来像这两个数组的东西会成为一个很好的测试用例:

N = 100
labels=np.arange(N)
contexts=np.random.randint(0,len(labels)/10,len(labels))

通过这些阵列,@ Hugh的改进速度提高了约10%。

我对其他问题的经验表明,defaultdict是收集这样的值的一种非常好的方式。唯一可能更快的是将其转换为某种numpy索引问题。