Python将文件中的每一行与其他所有行进行比较

时间:2016-02-13 06:38:33

标签: python

我正在实施一个统计程序,并且已经创建了一个性能瓶颈,并希望我能从社区获得一些帮助,从而可能指出我的优化方向。

我正在为文件中的每一行创建一个集合,并通过比较同一文件中每行的集合数据来查找该集合的交集。然后我使用该交集的大小来过滤输出中的某些集合。问题是我有一个嵌套的for循环(O(n 2 )),并且进入程序的文件的标准大小刚刚超过20,000行。我已经计算了算法,并且在大约20分钟内它运行不到500行,但对于大文件,它需要大约8小时才能完成。

我有16GB的RAM和一个非常快的4核Intel i7处理器。通过复制list1并使用第二个列表进行比较而不是再次打开文件,我注意到内存使用没有显着差异(可能这是因为我有一个SSD?)。我认为'打开'机制直接读取/写入HDD,速度较慢但在使用两个列表时没有注意到差异。实际上,该程序在操作期间很少使用超过1GB的RAM。

我希望其他人使用某种数据类型,或者更好地理解Python中的多处理,并且他们可以帮助我加快速度。我感谢任何帮助,我希望我的代码写得太糟糕。

import ast, sys, os, shutil
list1 = []
end = 0
filterValue = 3

# creates output file with filterValue appended to name
with open(arg2 + arg1 + "/filteredSets" + str(filterValue) , "w") as outfile:
    with open(arg2 + arg1 + "/file", "r") as infile:
        # create a list of sets of rows in file
        for row in infile:
            list1.append(set(ast.literal_eval(row)))

            infile.seek(0)
            for row in infile:
                # if file only has one row, no comparisons need to be made
                if not(len(list1) == 1):
                # get the first set from the list and...
                    set1 = set(ast.literal_eval(row))
                    # ...find the intersection of every other set in the file
                    for i in range(0, len(list1)):
                        # don't compare the set with itself
                        if not(pos == i):
                            set2 = list1[i]
                            set3 = set1.intersection(set2)
                            # if the two sets have less than 3 items in common
                            if(len(set3) < filterValue):
                                # and you've reached the end of the file
                                if(i == len(list1)):
                                    # append the row in outfile
                                    outfile.write(row)
                                    # increase position in infile
                                    pos += 1
                            else:
                                break
                        else:
                            outfile.write(row)

示例输入将是具有以下格式的文件:

[userID1, userID2, userID3]
[userID5, userID3, userID9]
[userID10, userID2, userID3, userID1]
[userID8, userID20, userID11, userID1]

输出文件如果是输入文件将是:

[userID5, userID3, userID9]
[userID8, userID20, userID11, userID1]

...因为删除的两个集合包含三个或更多相同的用户ID。

Visual Example

2 个答案:

答案 0 :(得分:0)

这个答案不是关于如何在函数,名称变量等中分割代码。它是关于复杂性的更快算法。

我会用字典。不会编写确切的代码,你可以自己动手。

Sets = dict()
for rowID, row in enumerate(Rows):
  for userID in row:
     if Sets.get(userID) is None:
       Sets[userID] = set()
     Sets[userID].add(rowID)

所以,现在我们有了一个字典,可以用来快速获取包含给定userID的行的行数。

BadRows = set()
for rowID, row in enumerate(Rows):
  Intersections = dict()
  for userID in row:
    for rowID_cmp in Sets[userID]: 
      if rowID_cmp != rowID:
        Intersections[rowID_cmp] = Intersections.get(rowID_cmp, 0) + 1
  # Now Intersections contains info about how many "times"
  # row numbered rowID_cmp intersectcs current row
  filteredOut = False
  for rowID_cmp in Intersections:
    if Intersections[rowID_cmp] >= filterValue:
      BadRows.add(rowID_cmp)
      filteredOut = True
  if filteredOut:
    BadRows.add(rowID)

将所有已过滤行的rownumbers保存到BadRows,现在我们最后一次迭代:

for rowID, row in enumerate(Rows):
  if rowID not in BadRows:
    # output row

这适用于3次扫描和O(nlogn)时间。也许你必须重做迭代Rows数组,因为它是你的情况下的一个文件,但并没有真正改变太多。

不确定python的语法和细节,但你明白了我的代码。

答案 1 :(得分:-1)

首先,请将您的代码打包到功能齐全的功能中。

def get_data(*args):
    # get the data.

def find_intersections_sets(list1, list2):
    # do the intersections part.

def loop_over_some_result(result):
    # insert assertions so that you don't end up looping in infinity:
    assert result is not None
    ...

def myfunc(*args):
    source1, source2 = args
    L1, L2 = get_data(source1), get_data(source2)
    intersects = find_intersections_sets(L1,L2)
    ...

if __name__ == "__main__":
    myfunc()

然后您可以使用以下方式轻松地分析代码:

if __name__ == "__main__":
    import cProfile
    cProfile.run('myfunc()')

它为您提供了对您的代码行为的宝贵见解,并允许您跟踪逻辑错误。有关cProfile的更多信息,请参阅How can you profile a python script?

追踪逻辑缺陷的选项(我们是所有人类,对吗?)是在this (python2)this (python3)这样的装饰中使用超时功能:

此时myfunc可以更改为:

def get_data(*args):
    # get the data.

def find_intersections_sets(list1, list2):
    # do the intersections part.

def myfunc(*args):
    source1, source2 = args
    L1, L2 = get_data(source1), get_data(source2)

    @timeout(10) # seconds <---- the clever bit!
    intersects = find_intersections_sets(L1,L2)
    ...

...如果超时操作需要太长时间,则会引发错误。

这是我最好的猜测:

import ast 

def get_data(filename):
    with open(filename, 'r') as fi:
        data = fi.readlines()
    return data

def get_ast_set(line):
    return set(ast.literal_eval(line))

def less_than_x_in_common(set1, set2, limit=3):
    if len(set1.intersection(set2)) < limit:
        return True
    else:
        return False

def check_infile(datafile, savefile, filtervalue=3):
    list1 = [get_ast_set(row) for row in get_data(datafile)]
    outlist = []
    for row in list1:
        if any([less_than_x_in_common(set(row), set(i), limit=filtervalue) for i in outlist]):
            outlist.append(row)
    with open(savefile, 'w') as fo:
        fo.writelines(outlist)

if __name__ == "__main__":
    datafile = str(arg2 + arg1 + "/file")
    savefile = str(arg2 + arg1 + "/filteredSets" + str(filterValue))
    check_infile(datafile, savefile)