在python中计算数组上if语句的快速方法?

时间:2015-01-24 05:08:01

标签: python arrays numpy

假设三个numpy数组x,y和z

    z = (x**2)/ y          for each  x > 2 y
    z = (x**2)/y**(3/2)    for each  x > 3 y
    z = (1/x)*sin(x)       for each  x > 4 y

数组x,y和z当然是由当然组成的,但它们说明了在多个数组上运行多个if语句的要点。数组x,y和z各约为500,000个元素。

一种可能的方式(很像FORTRAN)是创建一个变量i来索引数组并使用它来测试x [i]> 2 * y [i]或x [i]> 3 * Y [i]中。我认为它会很慢。

我需要一种快速,优雅且更加pythonic的方式来计算数组z。

更新:我尝试了两种方法,结果如下:

   # Fortran way of loops: 
   import numpy as np

   x=np.random.rand(40000,1)
   y=np.random.rand(40000,1)

   z = np.zeros(x.shape)
   for i, v in enumerate(x):
        #print i
        if x[i] >2*y[i]:
            z[i]= x[i]**2/y[i]
        if x[i] > 3*y[i]:
            z[i]=x[i]**2/y[i]**(1.5)
        if x[i] > 4*y[i]:
            z[i] = (1/x[i])*np.sin(x[i])

    z = np.zeros(x.shape)
    print z
    #end----

时间结果如下:

    real    0m0.920s
    user    0m0.900s
     sys    0m0.016s

使用的另一段代码是:

    # Pythonic way
    import numpy as np

    x=np.random.rand(40000,1)
    y=np.random.rand(40000,1)

    indices1 = np.where(x > 2*y)
    indices2 = np.where(x > 3*y)
    indices3 = np.where(x > 4*y)

    z = np.zeros(x.shape)
    z[indices1] = x[indices1]**2/y[indices1]
    z[indices2] = x[indices2]**2/y[indices2]**(1.5)
    z[indices3] = (1/x[indices3])*np.sin(x[indices3]) 
    print z
    # end of code -----

时间结果如下:

    real    0m0.110s
    user    0m0.076s
     sys    0m0.028s

因此执行时间差异很大。这两个部分在使用python 2.7.5的ubuntu虚拟机上运行

更新:我使用

进行了另一项测试
    indices1 = x > 2*y
    indices2 = x > 3*y
    indices3 = x > 4*y

时间结果如下:

     real   0m0.105s
     user   0m0.084s
      sys   0m0.016s

总结:方法3比使用np.where更优雅,更快。使用显式循环非常慢。

1 个答案:

答案 0 :(得分:2)

我不太确定你的z阵列是否与x或y的大小相同,但我会假设。

Numpy有一个函数可以根据条件找到元素的索引。 在下面的示例中,我正在进行类似于第一行的计算。

import numpy as np

x = np.arange(4)
x[2:] += 10
print x

y = np.arange(4)
print y

indices = np.where(x > 2*y)
print indices

z = np.zeros(x.shape)
z[indices] = x[indices]**2/y[indices]
print z

print语句产生以下结果:

x:[0 1 12 13]

y:[0 1 2 3]

指数:[2,3]

z:[0 0 72 56]

编辑: 经过进一步的测试,结果证明你甚至不需要使用numpy where功能。你可以简单地设置indices = x> 2 * Y