我正在实施一个统计程序,并且已经创建了一个性能瓶颈,并希望我能从社区获得一些帮助,从而可能指出我的优化方向。
我正在为文件中的每一行创建一个集合,并通过比较同一文件中每行的集合数据来查找该集合的交集。然后我使用该交集的大小来过滤输出中的某些集合。问题是我有一个嵌套的for循环(O(n 2 )),并且进入程序的文件的标准大小刚刚超过20,000行。我已经计算了算法,并且在大约20分钟内它运行不到500行,但对于大文件,它需要大约8小时才能完成。
我有16GB的RAM和一个非常快的4核Intel i7处理器。通过复制list1并使用第二个列表进行比较而不是再次打开文件,我注意到内存使用没有显着差异(可能这是因为我有一个SSD?)。我认为'打开'机制直接读取/写入HDD,速度较慢但在使用两个列表时没有注意到差异。实际上,该程序在操作期间很少使用超过1GB的RAM。
我希望其他人使用某种数据类型,或者更好地理解Python中的多处理,并且他们可以帮助我加快速度。我感谢任何帮助,我希望我的代码写得太糟糕。
import ast, sys, os, shutil
list1 = []
end = 0
filterValue = 3
# creates output file with filterValue appended to name
with open(arg2 + arg1 + "/filteredSets" + str(filterValue) , "w") as outfile:
with open(arg2 + arg1 + "/file", "r") as infile:
# create a list of sets of rows in file
for row in infile:
list1.append(set(ast.literal_eval(row)))
infile.seek(0)
for row in infile:
# if file only has one row, no comparisons need to be made
if not(len(list1) == 1):
# get the first set from the list and...
set1 = set(ast.literal_eval(row))
# ...find the intersection of every other set in the file
for i in range(0, len(list1)):
# don't compare the set with itself
if not(pos == i):
set2 = list1[i]
set3 = set1.intersection(set2)
# if the two sets have less than 3 items in common
if(len(set3) < filterValue):
# and you've reached the end of the file
if(i == len(list1)):
# append the row in outfile
outfile.write(row)
# increase position in infile
pos += 1
else:
break
else:
outfile.write(row)
示例输入将是具有以下格式的文件:
[userID1, userID2, userID3]
[userID5, userID3, userID9]
[userID10, userID2, userID3, userID1]
[userID8, userID20, userID11, userID1]
输出文件如果是输入文件将是:
[userID5, userID3, userID9]
[userID8, userID20, userID11, userID1]
...因为删除的两个集合包含三个或更多相同的用户ID。
答案 0 :(得分:0)
这个答案不是关于如何在函数,名称变量等中分割代码。它是关于复杂性的更快算法。
我会用字典。不会编写确切的代码,你可以自己动手。
Sets = dict()
for rowID, row in enumerate(Rows):
for userID in row:
if Sets.get(userID) is None:
Sets[userID] = set()
Sets[userID].add(rowID)
所以,现在我们有了一个字典,可以用来快速获取包含给定userID的行的行数。
BadRows = set()
for rowID, row in enumerate(Rows):
Intersections = dict()
for userID in row:
for rowID_cmp in Sets[userID]:
if rowID_cmp != rowID:
Intersections[rowID_cmp] = Intersections.get(rowID_cmp, 0) + 1
# Now Intersections contains info about how many "times"
# row numbered rowID_cmp intersectcs current row
filteredOut = False
for rowID_cmp in Intersections:
if Intersections[rowID_cmp] >= filterValue:
BadRows.add(rowID_cmp)
filteredOut = True
if filteredOut:
BadRows.add(rowID)
将所有已过滤行的rownumbers保存到BadRows,现在我们最后一次迭代:
for rowID, row in enumerate(Rows):
if rowID not in BadRows:
# output row
这适用于3次扫描和O(nlogn)时间。也许你必须重做迭代Rows数组,因为它是你的情况下的一个文件,但并没有真正改变太多。
不确定python的语法和细节,但你明白了我的代码。
答案 1 :(得分:-1)
首先,请将您的代码打包到功能齐全的功能中。
def get_data(*args):
# get the data.
def find_intersections_sets(list1, list2):
# do the intersections part.
def loop_over_some_result(result):
# insert assertions so that you don't end up looping in infinity:
assert result is not None
...
def myfunc(*args):
source1, source2 = args
L1, L2 = get_data(source1), get_data(source2)
intersects = find_intersections_sets(L1,L2)
...
if __name__ == "__main__":
myfunc()
然后您可以使用以下方式轻松地分析代码:
if __name__ == "__main__":
import cProfile
cProfile.run('myfunc()')
它为您提供了对您的代码行为的宝贵见解,并允许您跟踪逻辑错误。有关cProfile的更多信息,请参阅How can you profile a python script?
追踪逻辑缺陷的选项(我们是所有人类,对吗?)是在this (python2)或this (python3)这样的装饰中使用超时功能:
此时myfunc
可以更改为:
def get_data(*args):
# get the data.
def find_intersections_sets(list1, list2):
# do the intersections part.
def myfunc(*args):
source1, source2 = args
L1, L2 = get_data(source1), get_data(source2)
@timeout(10) # seconds <---- the clever bit!
intersects = find_intersections_sets(L1,L2)
...
...如果超时操作需要太长时间,则会引发错误。
这是我最好的猜测:
import ast
def get_data(filename):
with open(filename, 'r') as fi:
data = fi.readlines()
return data
def get_ast_set(line):
return set(ast.literal_eval(line))
def less_than_x_in_common(set1, set2, limit=3):
if len(set1.intersection(set2)) < limit:
return True
else:
return False
def check_infile(datafile, savefile, filtervalue=3):
list1 = [get_ast_set(row) for row in get_data(datafile)]
outlist = []
for row in list1:
if any([less_than_x_in_common(set(row), set(i), limit=filtervalue) for i in outlist]):
outlist.append(row)
with open(savefile, 'w') as fo:
fo.writelines(outlist)
if __name__ == "__main__":
datafile = str(arg2 + arg1 + "/file")
savefile = str(arg2 + arg1 + "/filteredSets" + str(filterValue))
check_infile(datafile, savefile)