我有两个数组(a和b),其中n个整数元素的范围是(0,N)。
错误:具有2 ^ n个整数的数组,其中最大整数的值为N = 3 ^ n
我想计算a和b中每个元素组合的总和(sum_ij_ = a_i_ + b_j_表示所有 i,j )。然后取模数N(sum_ij_ = sum_ij_%N),最后计算不同总和的频率。
为了快速执行numpy,没有任何循环,我尝试使用meshgrid和bincount函数。
A,B = numpy.meshgrid(a,b)
A = A + B
A = A % N
A = numpy.reshape(A,A.size)
result = numpy.bincount(A)
现在,问题是我的输入数组很长。当我使用2 ^ 13个元素的输入时,meshgrid给了我MemoryError。我想为具有2 ^ 15-2 ^ 20个元素的数组计算这个。
n为15到20
用numpy做这个有什么聪明的技巧吗?
任何帮助都将受到高度赞赏。
- 乔恩
答案 0 :(得分:1)
编辑以回应jonalm的评论:
jonalm:N~3 ^ n不是n~3 ^ N. N是a中的max元素,n是数字 a。中的元素。
n是~2 ^ 20。如果N是~3 ^ n那么N是〜3 ^(2 ^ 20)> 10 ^(500207)。 科学家估计(http://www.stormloader.com/ajy/reallife.html)宇宙中只有大约10 ^ 87个粒子。所以没有(天真的)计算机可以处理大小为10 ^(500207)的int。
jonalm:我对你定义的pv()函数有点古怪。 (一世 不设法运行它,因为text.find()没有定义(猜测它在另一个 模块))。这个功能如何运作?它有什么优势?
pv是我编写的用于调试变量值的小辅助函数。它很像 print(),除非你说pv(x),它打印文字变量名(或表达式字符串),冒号,然后是变量的值。
如果你把
#!/usr/bin/env python
import traceback
def pv(var):
(filename,line_number,function_name,text)=traceback.extract_stack()[-2]
print('%s: %s'%(text[text.find('(')+1:-1],var))
x=1
pv(x)
在脚本中你应该得到
x: 1
使用pv而不是打印的最小优点是它可以节省您的打字。而不是必须 写
print('x: %s'%x)
你可以打下来
pv(x)
当要跟踪多个变量时,标记变量很有帮助。 我只是厌倦了全力以赴。
pv函数的工作原理是使用traceback模块来查看代码行 用于调用pv函数本身。 (参见http://docs.python.org/library/traceback.html#module-traceback)该行代码作为字符串存储在变量文本中。 text.find()是对通常的字符串方法find()的调用。例如,如果
text='pv(x)'
然后
text.find('(') == 2 # The index of the '(' in string text
text[text.find('(')+1:-1] == 'x' # Everything in between the parentheses
我假设n~3 ^ N,n~2 ** 20
这个想法是工作模块N.这减少了数组的大小。 第二个想法(当n很大时很重要)是使用'object'类型的numpy ndarrays,因为如果使用整数dtype,则存在溢出允许的最大整数大小的风险。
#!/usr/bin/env python
import traceback
import numpy as np
def pv(var):
(filename,line_number,function_name,text)=traceback.extract_stack()[-2]
print('%s: %s'%(text[text.find('(')+1:-1],var))
您可以将n更改为2 ** 20,但在下面我会显示小n会发生什么 所以输出更容易阅读。
n=100
N=int(np.exp(1./3*np.log(n)))
pv(N)
# N: 4
a=np.random.randint(N,size=n)
b=np.random.randint(N,size=n)
pv(a)
pv(b)
# a: [1 0 3 0 1 0 1 2 0 2 1 3 1 0 1 2 2 0 2 3 3 3 1 0 1 1 2 0 1 2 3 1 2 1 0 0 3
# 1 3 2 3 2 1 1 2 2 0 3 0 2 0 0 2 2 1 3 0 2 1 0 2 3 1 0 1 1 0 1 3 0 2 2 0 2
# 0 2 3 0 2 0 1 1 3 2 2 3 2 0 3 1 1 1 1 2 3 3 2 2 3 1]
# b: [1 3 2 1 1 2 1 1 1 3 0 3 0 2 2 3 2 0 1 3 1 0 0 3 3 2 1 1 2 0 1 2 0 3 3 1 0
# 3 3 3 1 1 3 3 3 1 1 0 2 1 0 0 3 0 2 1 0 2 2 0 0 0 1 1 3 1 1 1 2 1 1 3 2 3
# 3 1 2 1 0 0 2 3 1 0 2 1 1 1 1 3 3 0 2 2 3 2 0 1 3 1]
wa保存a中的0,1s,2s,3s的数量 wb保存b中的0,1,2,3的数量
wa=np.bincount(a)
wb=np.bincount(b)
pv(wa)
pv(wb)
# wa: [24 28 28 20]
# wb: [21 34 20 25]
result=np.zeros(N,dtype='object')
将0视为令牌或芯片。同样适用于1,2,3。
想想wa = [24 28 28 20]意味着有一个包含24个0芯片,28个1芯片,28个2芯片,20个3芯片的包。
你有一个wa-bag和一个wb-bag。当您从每个包中绘制一个芯片时,您将它们“添加”在一起并形成一个新芯片。你“修改”答案(模N)。
想象一下从wb-bag中取出1片芯片并将其与每个芯片一起添加到wa-bag中。
1-chip + 0-chip = 1-chip
1-chip + 1-chip = 2-chip
1-chip + 2-chip = 3-chip
1-chip + 3-chip = 4-chip = 0-chip (we are mod'ing by N=4)
由于wb包中有34个1芯片,当你将它们添加到wa = [24 28 28 20]包中的所有芯片时,你得到了
34*24 1-chips
34*28 2-chips
34*28 3-chips
34*20 0-chips
由于34个1芯片,这只是部分计数。你还必须处理另一个 wb-bag中的芯片类型,但是这显示了下面使用的方法:
for i,count in enumerate(wb):
partial_count=count*wa
pv(partial_count)
shifted_partial_count=np.roll(partial_count,i)
pv(shifted_partial_count)
result+=shifted_partial_count
# partial_count: [504 588 588 420]
# shifted_partial_count: [504 588 588 420]
# partial_count: [816 952 952 680]
# shifted_partial_count: [680 816 952 952]
# partial_count: [480 560 560 400]
# shifted_partial_count: [560 400 480 560]
# partial_count: [600 700 700 500]
# shifted_partial_count: [700 700 500 600]
pv(result)
# result: [2444 2504 2520 2532]
这是最终结果:2444 0s,2504 1s,2520 2s,2532 3s。
# This is a test to make sure the result is correct.
# This uses a very memory intensive method.
# c is too huge when n is large.
if n>1000:
print('n is too large to run the check')
else:
c=(a[:]+b[:,np.newaxis])
c=c.ravel()
c=c%N
result2=np.bincount(c)
pv(result2)
assert(all(r1==r2 for r1,r2 in zip(result,result2)))
# result2: [2444 2504 2520 2532]