建议我的代码中更快的for / if语句?

时间:2013-08-19 06:44:58

标签: python performance if-statement for-loop numpy

我的代码需要大约两个小时来处理。瓶颈在于for循环和if 语句(参见代码中的注释)。 我是python的初学者:)任何人都可以推荐一种有效的python方法来替换嵌套的for和if语句吗?

我有大约3000万行的表,每行有(x,y,z)值:

  

20.0 11.3 7
  21.0 11.3 0
  22.0 11.3 3
  ......

我想要的输出是x,y,min(z),count(min(z))形式的表格。最后 column是该(x,y)处最小z值的最终计数。例如:

  

20.0 11.3 7 7
  21.0 11.3 0 10
  22.0 11.3 3 1
  ......

只有大约600个唯一坐标,因此输出表将为600x4。 我的代码:

import numpy as np
file = open('input.txt','r');

coordset = set()
data = np.zeros((600,4))*np.nan
irow = 0 
ctr = 0 

for row in file:
    item = row.split()
    x = float(item[0])
    y = float(item[1])
    z = float(item[2])

    # build unique grid of coords
    if ((x,y)) not in coordset:
        data[irow][0] = x 
        data[irow][1] = y 
        data[irow][2] = z 
        irow = irow + 1     # grows up to 599 

    # lookup table of unique coords
    coordset.add((x,y))

    # BOTTLENECK. replace ifs? for?
    for i in range(0, irow):
        if data[i][0]==x and data[i][1]==y:
            if z > data[i][2]:
                continue
            elif z==data[i][2]:
                ctr = ctr + 1
                data[i][3]=ctr
            if z < data[i][2]:
                data[i][2] = z
                ctr = 1
                data[i][3]=ctr

编辑:@Joowani的方法在1分26秒内进行计算。我的原始方法,相同的计算机,相同的数据文件,106m23s。 edit2: @Ophion和@Sibster感谢您的建议,我没有足够的信誉来+1有用的答案。

3 个答案:

答案 0 :(得分:2)

您的解决方案似乎很慢,因为它会在您进行更新的每个时间内遍历列表(即数据)。更好的方法是使用字典,每次更新需要O(1)而不是O(n)。

这是我使用字典的解决方案:

file = open('input.txt', 'r')

#coordinates
c = {}

for line in file:
    #items
    (x, y, z) = (float(n) for n in line.split())

    if (x, y) not in c:
        c[(x, y)] = [z, 1]
    elif c[(x, y)][0] > z:
        c[(x, y)][0], c[(x, y)][1] = z, 1
    elif c[(x, y)][0] == z:
        c[(x, y)][1] += 1

for key in c:
    print("{} {} {} {}".format(key[0], key[1], c[key][0], c[key][1]))

答案 1 :(得分:0)

为什么不将最后一个if改为elif?

就像现在这样,你将评估循环的每次迭代z < data[i][2]:

您甚至可以用其他内容替换它,因为您已经选中if z>data[i][2]z == data[i][2],因此唯一剩下的可能性是z < data[i][2]:

所以下面的代码会做同样的事情并且应该更快:

        if z > data[i][2]:
            continue
        elif z==data[i][2]:
            ctr = ctr + 1
            data[i][3]=ctr
        else:
            data[i][2] = z
            ctr = 1
            data[i][3]=ctr

答案 2 :(得分:0)

要在numpy使用np.unique中执行此操作。

def count_unique(arr):
    row_view=np.ascontiguousarray(a).view(np.dtype((np.void,a.dtype.itemsize * a.shape[1])))
    ua, uind = np.unique(row_view,return_inverse=True)
    unique_rows = ua.view(a.dtype).reshape(ua.shape + (-1,))
    count=np.bincount(uind)
    return np.hstack((unique_rows,count[:,None]))

首先让我们检查一个小数组:

a=np.random.rand(10,3)
a=np.around(a,0)

print a
[[ 0.  0.  0.]
 [ 0.  1.  1.]
 [ 0.  1.  0.]
 [ 1.  0.  0.]
 [ 0.  1.  1.]
 [ 1.  1.  0.]
 [ 1.  0.  1.]
 [ 1.  0.  1.]
 [ 1.  0.  0.]
 [ 0.  0.  0.]]

 print output
[[ 0.  0.  0.  2.]
 [ 0.  1.  0.  1.]
 [ 0.  1.  1.  2.]
 [ 1.  0.  0.  2.]
 [ 1.  0.  1.  2.]
 [ 1.  1.  0.  1.]]

 print np.sum(output[:,-1])
 10

看起来不错!现在让我们检查一个大型数组:

a=np.random.rand(3E7,3)
a=np.around(a,1)

output=count_unique(a)
print output.shape
(1331, 4)  #Close as I can get to 600 unique elements.

print np.sum(output[:,-1])
30000000.0

在我的机器和3GB内存上大约需要33秒,在内存中为大型阵列执行此操作可能是您的瓶颈。作为参考,@ Joowani的解决方案花了大约130秒,虽然这是一个苹果和橙子的比较,因为我们从一个numpy数组开始。你的milage可能会有所不同。

要将数据作为numpy数组读入,我会查看问题here,但它应该如下所示:

arr=np.genfromtxt("./input.txt", delimiter=" ")

从txt文件加载大量数据我真的建议使用该链接中的pandas示例。