对numpy数组的所有值执行操作,引用i和j

时间:2012-05-30 17:43:23

标签: python arrays performance numpy indexing

我试图通过在2d数组上应用操作来改善numpy性能,问题是数组中每个元素的值取决于该元素的i,j位置。

显然,这样做的简单方法是使用嵌套的for循环,但我想知道是否有更好的方法可以通过引用np.indices或沿着这些行的东西?这是我的“愚蠢”代码:

for J in range(1025):
    for I in range(1025):
        PSI[I][J] = A*math.sin((float(I+1)-.5)*DI)*math.sin((float(J+1)-.5)*DJ)
        P[I][J] = PCF*(math.cos(2.*float(I)*DI)+math.cos(2.*float(J)*DJ))+50000.

3 个答案:

答案 0 :(得分:7)

由于您在两个数组之间进行乘法运算,因此在使用arange获取sin / cos数组后,可以使用outer函数。

像这样的东西(使用numpy的trig函数,因为它们是矢量化的)

PSI_i = numpy.sin((arange(1,1026)-0.5)*DI)
PSI_j = numpy.sin((arange(1,1026)-0.5)*DJ)
PSI = A*outer(PSI_i, PSI_j)

P_i = numpy.cos(2.*arange(1,1026)*DI)
P_j = numpy.cos(2.*arange(1,1026)*DJ)
P = PCF*outer(P_i, P_j) + 50000

如果您的环境是使用from numpy import *from pylab import *设置的,那么在您的trig功能之前不需要那些numpy.前缀。我让他们把它们与math区别开来,这对这种方法不起作用。

答案 1 :(得分:1)

您可以使用索引获取索引值的网格:

I,J=np.indices(PSI.shape)
#All constants set to one
PSI2=np.sin(I+1-.5)*np.sin(J+1-.5)
print PSI-PSI2 # should be zero.

我用ipython做了一些时间:

import numpy as np
import math
A = 1
P = 1
DI = 1
DJ = 1

def a():
    PSI=np.zeros((1025,1025))
    for J in range(1025):
        for I in range(1025):
            PSI[I][J] = A*math.sin((float(I+1)-.5)*DI)*math.sin((float(J+1)-.5)*DJ)
%timeit a()

def b():
    PSI=np.zeros((1025,1025))
    for I,J in np.ndindex(*PSI.shape):
        PSI[I,J] = A*math.sin((float(I+1)-.5)*DI)*math.sin((float(J+1)-.5)*DJ)        
%timeit b()

def c():
    I,J=np.indices((1025, 1025))
    P2=A*np.sin((I+1-.5)*DI)*np.sin((J+1-.5)*DJ)    
%timeit c()

def d():
    PSI_i = np.sin((np.arange(1,1026)-0.5)*DI)
    PSI_j = np.sin((np.arange(1,1026)-0.5)*DJ)
    PSI = A*np.outer(PSI_i, PSI_j)    
%timeit d()

结果在我的机器上并不令人惊讶:

1 loops, best of 3: 1.75 s per loop
1 loops, best of 3: 3.51 s per loop
10 loops, best of 3: 77.1 ms per loop
100 loops, best of 3: 7.16 ms per loop

答案 2 :(得分:0)

尝试numpy的ndenumerate函数,它返回值以及索引:

>>> a
array([[5, 5, 5],
       [1, 2, 3]])


>>> for index, value in numpy.ndenumerate(a):
...     print index, value

(0, 0) 5
(0, 1) 5
(0, 2) 5
(1, 0) 1
(1, 1) 2
(1, 2) 3