Numpy,长数组的问题

时间:2009-11-08 19:07:52

标签: python math numpy

我有两个数组(a和b),其中n个整数元素的范围是(0,N)。

错误:具有2 ^ n个整数的数组,其中最大整数的值为N = 3 ^ n

我想计算a和b中每个元素组合的总和(sum_ij_ = a_i_ + b_j_表示所有 i,j )。然后取模数N(sum_ij_ = sum_ij_%N),最后计算不同总和的频率。

为了快速执行numpy,没有任何循环,我尝试使用meshgrid和bincount函数。

A,B = numpy.meshgrid(a,b)
A = A + B
A = A % N
A = numpy.reshape(A,A.size)
result = numpy.bincount(A)

现在,问题是我的输入数组很长。当我使用2 ^ 13个元素的输入时,meshgrid给了我MemoryError。我想为具有2 ^ 15-2 ^ 20个元素的数组计算这个。

n为15到20

用numpy做这个有什么聪明的技巧吗?

任何帮助都将受到高度赞赏。

- 乔恩

1 个答案:

答案 0 :(得分:1)

编辑以回应jonalm的评论:

  

jonalm:N~3 ^ n不是n~3 ^ N. N是a中的max元素,n是数字   a。中的元素。

n是~2 ^ 20。如果N是~3 ^ n那么N是〜3 ^(2 ^ 20)> 10 ^(500207)。 科学家估计(http://www.stormloader.com/ajy/reallife.html)宇宙中只有大约10 ^ 87个粒子。所以没有(天真的)计算机可以处理大小为10 ^(500207)的int。

  

jonalm:我对你定义的pv()函数有点古怪。 (一世   不设法运行它,因为text.find()没有定义(猜测它在另一个   模块))。这个功能如何运作?它有什么优势?

pv是我编写的用于调试变量值的小辅助函数。它很像 print(),除非你说pv(x),它打印文字变量名(或表达式字符串),冒号,然后是变量的值。

如果你把

#!/usr/bin/env python
import traceback
def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))
x=1
pv(x)

在脚本中你应该得到

x: 1

使用pv而不是打印的最小优点是它可以节省您的打字。而不是必须 写

print('x: %s'%x)

你可以打下来

pv(x)

当要跟踪多个变量时,标记变量很有帮助。 我只是厌倦了全力以赴。

pv函数的工作原理是使用traceback模块来查看代码行 用于调用pv函数本身。 (参见http://docs.python.org/library/traceback.html#module-traceback)该行代码作为字符串存储在变量文本中。 text.find()是对通常的字符串方法find()的调用。例如,如果

text='pv(x)'

然后

text.find('(') == 2               # The index of the '(' in string text
text[text.find('(')+1:-1] == 'x'  # Everything in between the parentheses

我假设n~3 ^ N,n~2 ** 20

这个想法是工作模块N.这减少了数组的大小。 第二个想法(当n很大时很重要)是使用'object'类型的numpy ndarrays,因为如果使用整数dtype,则存在溢出允许的最大整数大小的风险。

#!/usr/bin/env python
import traceback
import numpy as np

def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))

您可以将n更改为2 ** 20,但在下面我会显示小n会发生什么 所以输出更容易阅读。

n=100
N=int(np.exp(1./3*np.log(n)))
pv(N)
# N: 4

a=np.random.randint(N,size=n)
b=np.random.randint(N,size=n)
pv(a)
pv(b)
# a: [1 0 3 0 1 0 1 2 0 2 1 3 1 0 1 2 2 0 2 3 3 3 1 0 1 1 2 0 1 2 3 1 2 1 0 0 3
#  1 3 2 3 2 1 1 2 2 0 3 0 2 0 0 2 2 1 3 0 2 1 0 2 3 1 0 1 1 0 1 3 0 2 2 0 2
#  0 2 3 0 2 0 1 1 3 2 2 3 2 0 3 1 1 1 1 2 3 3 2 2 3 1]
# b: [1 3 2 1 1 2 1 1 1 3 0 3 0 2 2 3 2 0 1 3 1 0 0 3 3 2 1 1 2 0 1 2 0 3 3 1 0
#  3 3 3 1 1 3 3 3 1 1 0 2 1 0 0 3 0 2 1 0 2 2 0 0 0 1 1 3 1 1 1 2 1 1 3 2 3
#  3 1 2 1 0 0 2 3 1 0 2 1 1 1 1 3 3 0 2 2 3 2 0 1 3 1]

wa保存a中的0,1s,2s,3s的数量 wb保存b中的0,1,2,3的数量

wa=np.bincount(a)
wb=np.bincount(b)
pv(wa)
pv(wb)
# wa: [24 28 28 20]
# wb: [21 34 20 25]
result=np.zeros(N,dtype='object')

将0视为令牌或芯片。同样适用于1,2,3。

想想wa = [24 28 28 20]意味着有一个包含24个0芯片,28个1芯片,28个2芯片,20个3芯片的包。

你有一个wa-bag和一个wb-bag。当您从每个包中绘制一个芯片时,您将它们“添加”在一起并形成一个新芯片。你“修改”答案(模N)。

想象一下从wb-bag中取出1片芯片并将其与每个芯片一起添加到wa-bag中。

1-chip + 0-chip = 1-chip
1-chip + 1-chip = 2-chip
1-chip + 2-chip = 3-chip
1-chip + 3-chip = 4-chip = 0-chip  (we are mod'ing by N=4)

由于wb包中有34个1芯片,当你将它们添加到wa = [24 28 28 20]包中的所有芯片时,你得到了

34*24 1-chips
34*28 2-chips
34*28 3-chips
34*20 0-chips

由于34个1芯片,这只是部分计数。你还必须处理另一个 wb-bag中的芯片类型,但是这显示了下面使用的方法:

for i,count in enumerate(wb):
    partial_count=count*wa
    pv(partial_count)
    shifted_partial_count=np.roll(partial_count,i)
    pv(shifted_partial_count)
    result+=shifted_partial_count
# partial_count: [504 588 588 420]
# shifted_partial_count: [504 588 588 420]
# partial_count: [816 952 952 680]
# shifted_partial_count: [680 816 952 952]
# partial_count: [480 560 560 400]
# shifted_partial_count: [560 400 480 560]
# partial_count: [600 700 700 500]
# shifted_partial_count: [700 700 500 600]

pv(result)    
# result: [2444 2504 2520 2532]

这是最终结果:2444 0s,2504 1s,2520 2s,2532 3s。

# This is a test to make sure the result is correct.
# This uses a very memory intensive method.
# c is too huge when n is large.
if n>1000:
    print('n is too large to run the check')
else:
    c=(a[:]+b[:,np.newaxis])
    c=c.ravel()
    c=c%N
    result2=np.bincount(c)
    pv(result2)
    assert(all(r1==r2 for r1,r2 in zip(result,result2)))
# result2: [2444 2504 2520 2532]