python help - 数组上的条件数学运算

时间:2014-08-03 14:19:19

标签: python-2.7

我正在学习 Python,但还是 OOP 新手。我想让函数 fA、fB、fC 一次性作用于整个 r-theta 空间,而不是一次只处理一个点。问题出在条件 (r <= 1) 上。下面的代码非常难看(!),但它能正常运行。

我怎样才能让这段代码更符合 Python 风格(更 pythonic)?谢谢!

(注意:在这个简化的例子中,(r > 1) 分支的公式要除以 r,因此当 r 趋于零时会发散。)

from math import pi, sin, cos, exp
import numpy as np
import matplotlib.pyplot as plt

def fA(rr,th,a,b,c):
    """Evaluate field A at a single polar point (rr, th).

    For rr <= 1 the components are a*sin(th), b*rr*cos(th) and c*rr;
    otherwise they are (a/rr)*sin(th), (b/rr)*cos(th) and c/rr.
    Returns the tuple (fx, fy, fz).
    """
    if rr <= 1:
        return (a * sin(th), b * rr * cos(th), c * rr)
    return ((a / rr) * sin(th), (b / rr) * cos(th), c / rr)

def fB(rr,th,a,b,c):
    """Evaluate field B at a single polar point (rr, th).

    Works with the doubled angle 2*th.  For rr <= 1 the components are
    b*sin(2th), a*rr*cos(2th) and c*rr; otherwise they are
    (b/rr)*sin(2th), (a/rr)*cos(2th) and the constant c.
    Returns the tuple (fx, fy, fz).
    """
    angle = 2.*th
    if rr <= 1:
        return (b * sin(angle), a * rr * cos(angle), c * rr)
    return ((b / rr) * sin(angle), (a / rr) * cos(angle), c)

def fC(rr,th,a,b,c):
    """Evaluate field C at a single polar point (rr, th).

    The in-plane components share one exponential factor, exp(rr - 1) for
    rr <= 1 and exp(1 - rr) otherwise, so it peaks at rr == 1.  The z
    component is c inside and c/rr outside.  The a and b arguments are not
    used; they keep the signature parallel to fA and fB.
    Returns the tuple (fx, fy, fz).
    """
    inside = rr <= 1
    if inside:
        decay = exp(rr - 1.)
        fz = c
    else:
        decay = exp(1. - rr)
        fz = c / rr
    return (decay * cos(th), decay * sin(th), fz)


# Grid parameters: nx columns by ny rows, spanning [-2, 2] on each axis.
nx = 101
ny = 101
dx = 4. / (nx-1)
dy = 4. / (ny-1)

# Node coordinates centred on the origin.  Dividing by the float 2. keeps the
# grid symmetric for even nx/ny under Python 2, where (nx-1)/2 is floor
# division.  meshgrid expands the 1-D axes into (ny, nx) arrays with
# X[iy, ix] == xs[ix] and Y[iy, ix] == ys[iy].
xs = dx * (np.arange(nx) - (nx-1)/2.)
ys = dy * (np.arange(ny) - (ny-1)/2.)
X, Y = np.meshgrid(xs, ys)

r = np.sqrt(X**2. + Y**2.)
theta = np.arctan2(Y,X)

# Output buffers, one (ny, nx) array per field component.
Ax = np.zeros((ny,nx))
Ay = np.zeros((ny,nx))
Az = np.zeros((ny,nx))

Bx = np.zeros((ny,nx))
By = np.zeros((ny,nx))
Bz = np.zeros((ny,nx))

Cx = np.zeros((ny,nx))
Cy = np.zeros((ny,nx))
Cz = np.zeros((ny,nx))

# fA, fB and fC branch on ``rr <= 1`` with a plain ``if``, so they accept
# scalars only and must be called once per grid point.
for iy, ix in np.ndindex(ny, nx):
    Ax[iy,ix], Ay[iy,ix], Az[iy,ix] = fA(r[iy,ix], theta[iy,ix], 1.0, 1.0, 1.5)
    Bx[iy,ix], By[iy,ix], Bz[iy,ix] = fB(r[iy,ix], theta[iy,ix], 1.5, 0.8, 1.0)
    Cx[iy,ix], Cy[iy,ix], Cz[iy,ix] = fC(r[iy,ix], theta[iy,ix], 0.9, 1.1, 1.2)

# Show all nine component arrays on a 3x3 grid of subplots: one row per
# field (A, B, C), one column per component (x, y, z), each panel with its
# own colour bar.
plt.figure()

panels = [('Ax', Ax), ('Ay', Ay), ('Az', Az),
          ('Bx', Bx), ('By', By), ('Bz', Bz),
          ('Cx', Cx), ('Cy', Cy), ('Cz', Cz)]
for slot, (label, field) in enumerate(panels, 1):
    plt.subplot(3, 3, slot)
    plt.imshow(field)
    plt.colorbar()
    plt.title(label)

plt.show()

2 个答案:

答案 0 :(得分:1)

以你的其中一个关系式为例,你可以用 numpy.where 代替 if ... 语句:

# np.where picks, element by element, the first expression where rr <= 1 and
# the second elsewhere; both expressions are evaluated over the whole array.
fx = np.where(rr <= 1, np.exp(rr - 1.) * np.cos(th), np.exp(1. - rr) * np.cos(th))

这相当于对整个数组逐元素执行 if/else,而不是一次只处理一个数。这样你就可以不用循环,直接写 Ax, Ay, Az = fA(...)。

您可以使用meshgrid或mgrid制作X和Y.

如果不想在每个元素上都计算两个分支,可以使用布尔掩码索引:

# Boolean masks select the elements for each branch, so each formula is only
# evaluated where it applies.  Using ~inside as the second mask guarantees
# that every element of the uninitialised empty_like buffer gets written.
inside = rr <= 1
outside = ~inside
fx = np.empty_like(rr)
fx[inside] = np.exp(rr[inside] - 1.) * np.cos(th[inside])
fx[outside] = np.exp(1. - rr[outside]) * np.cos(th[outside])

答案 1 :(得分:1)

正如 @mdurant 所说,np.where 和 np.meshgrid 会很有用。下面我重新组织了你的代码,并用 numpy 高级索引(advanced indexing)提供另一种避免 Python 循环的方法:

import sys
from math import pi, sin, cos, exp
import numpy as np
import matplotlib.pyplot as plt


def _generate_coordinate(nx, ny):
    """
    Generate coordinate data points in a function to prevent namespace
    pollution.
    """
    dx = 4. / (nx-1)
    dy = 4. / (ny-1)
    X = np.zeros((ny,nx))
    Y = np.zeros((ny,nx))
    for ix in range(nx):
        for iy in range(ny):
            X[iy,ix] = dx*(ix - (nx-1)/2)
            Y[iy,ix] = dy*(iy - (ny-1)/2)
    return np.sqrt(X**2 + Y**2), np.arctan2(Y,X)

nx = ny = 101
# Module-level grid shared by calculate_numpy() and calculate_loop(); pass
# nx/ny rather than repeating the literal so the two can never disagree.
r, theta = _generate_coordinate(nx, ny)


def calculate_numpy(r_grid=None, theta_grid=None):
    """Evaluate fields A, B and C on a polar grid with vectorized NumPy code.

    Each helper first evaluates its ``rr <= 1`` formula at every point.
    That formula does not divide by ``rr``, so it is finite at the origin.
    The helper then overwrites the ``rr > 1`` points through a boolean mask,
    so the ``1 / rr`` formulas never see ``rr == 0``.

    Parameters
    ----------
    r_grid, theta_grid : array_like, optional
        Radius and angle arrays of the same shape.  Default to the
        module-level ``r`` and ``theta`` grids.

    Returns
    -------
    tuple of numpy.ndarray
        ``(Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz)``, each shaped like ``r_grid``.
    """
    if r_grid is None:
        r_grid = r
    if theta_grid is None:
        theta_grid = theta
    # Accept lists too; for float64 arrays this is a no-copy pass-through.
    # The inputs are never modified: the helpers only rebind rr/th locally.
    r_grid = np.asarray(r_grid, dtype=float)
    theta_grid = np.asarray(theta_grid, dtype=float)

    # Helper methods for the vector-based calculator.
    def fA(rr, th, a, b, c):
        # Fill every point with the rr <= 1 formula (safe at rr == 0).
        arrx = a * np.sin(th)
        arry = b * rr * np.cos(th)
        arrz = c * rr
        # Overwrite the points outside the unit circle.
        slct = rr > 1
        rr = rr[slct]
        th = th[slct]
        arrx[slct] = a / rr * np.sin(th)
        arry[slct] = b / rr * np.cos(th)
        arrz[slct] = c / rr
        return arrx, arry, arrz

    def fB(rr, th, a, b, c):
        # Fill every point with the rr <= 1 formula (safe at rr == 0).
        arrx = b * np.sin(2.*th)
        arry = a * rr * np.cos(2.*th)
        arrz = c * rr
        # Overwrite the points outside the unit circle.
        slct = rr > 1
        rr = rr[slct]
        th = th[slct]
        arrx[slct] = b / rr * np.sin(2.*th)
        arry[slct] = a / rr * np.cos(2.*th)
        arrz[slct] = c
        return arrx, arry, arrz

    def fC(rr, th, a, b, c):
        # Fill every point with the rr <= 1 formula (safe at rr == 0).
        arrx = np.exp(rr-1) * np.cos(th)
        arry = np.exp(rr-1) * np.sin(th)
        arrz = np.full_like(rr, c, dtype=float)
        # Overwrite the points outside the unit circle.
        slct = rr > 1
        rr = rr[slct]
        th = th[slct]
        arrx[slct] = np.exp(1.-rr) * np.cos(th)
        arry[slct] = np.exp(1.-rr) * np.sin(th)
        arrz[slct] = c / rr
        return arrx, arry, arrz

    # Carry out calculation.
    Ax, Ay, Az = fA(r_grid, theta_grid, 1.0, 1.0, 1.5)
    Bx, By, Bz = fB(r_grid, theta_grid, 1.5, 0.8, 1.0)
    Cx, Cy, Cz = fC(r_grid, theta_grid, 0.9, 1.1, 1.2)
    return Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz


def calculate_loop(r_grid=None, theta_grid=None):
    """Evaluate fields A, B and C point by point with plain Python loops.

    This is the slow reference implementation.  Every grid point is passed
    as a pair of scalars to ``if``/``else`` helpers.

    Parameters
    ----------
    r_grid, theta_grid : array_like, optional
        Radius and angle arrays of the same shape.  Default to the
        module-level ``r`` and ``theta`` grids.

    Returns
    -------
    tuple of numpy.ndarray
        ``(Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz)``, each shaped like ``r_grid``.
    """
    if r_grid is None:
        r_grid = r
    if theta_grid is None:
        theta_grid = theta
    r_grid = np.asarray(r_grid, dtype=float)
    theta_grid = np.asarray(theta_grid, dtype=float)

    # Helper methods for loop calculator.
    def fA(rr,th,a,b,c):
        if (rr<=1):
            fx = a * sin(th)
            fy = b * rr * cos(th)
            fz = c * rr
        else:
            fx = (a / rr) * sin(th)
            fy = (b / rr) * cos(th)
            fz = (c / rr)
        return(fx,fy,fz)
    def fB(rr,th,a,b,c):
        if (rr<=1):
            fx = b * sin(2.*th)
            fy = a * rr * cos(2.*th)
            fz = c * rr
        else:
            fx = (b / rr) * sin(2.*th)
            fy = (a / rr) * cos(2.*th)
            fz = c
        return(fx,fy,fz)
    def fC(rr,th,a,b,c):
        if (rr<=1):
            fx = exp(rr - 1.) * cos(th)
            fy = exp(rr - 1.) * sin(th)
            fz = c
        else:
            fx = exp(1. - rr) * cos(th)
            fy = exp(1. - rr) * sin(th)
            fz = c / rr
        return(fx,fy,fz)

    # Output buffers shaped like the input grid, filled one point at a time.
    shape = r_grid.shape
    Ax = np.zeros(shape)
    Ay = np.zeros(shape)
    Az = np.zeros(shape)
    Bx = np.zeros(shape)
    By = np.zeros(shape)
    Bz = np.zeros(shape)
    Cx = np.zeros(shape)
    Cy = np.zeros(shape)
    Cz = np.zeros(shape)
    # Carry out calculation with Python loops.  This is slow.
    for idx in np.ndindex(*shape):
        rr, th = r_grid[idx], theta_grid[idx]
        Ax[idx], Ay[idx], Az[idx] = fA(rr, th, 1.0, 1.0, 1.5)
        Bx[idx], By[idx], Bz[idx] = fB(rr, th, 1.5, 0.8, 1.0)
        Cx[idx], Cy[idx], Cz[idx] = fC(rr, th, 0.9, 1.1, 1.2)
    return Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz


def main():
    """Compute the nine field components and show them in a 3x3 figure.

    The first command-line argument selects the implementation: ``numpy``
    (used when no argument is given) or ``loop``.  Any other value exits
    with a usage message.
    """
    # Explicit dispatch table: only the two intended calculators can be
    # selected from the command line.
    calculators = {"numpy": calculate_numpy, "loop": calculate_loop}
    choice = sys.argv[1] if len(sys.argv) > 1 else "numpy"
    if choice not in calculators:
        sys.exit("usage: %s [%s]" % (sys.argv[0], "|".join(sorted(calculators))))
    results = calculators[choice]()
    titles = ('Ax', 'Ay', 'Az', 'Bx', 'By', 'Bz', 'Cx', 'Cy', 'Cz')

    plt.figure()
    # One panel per component, filled row by row: A on top, C at the bottom.
    for panel, (title, data) in enumerate(zip(titles, results), start=1):
        plt.subplot(3, 3, panel)
        plt.imshow(data)
        plt.colorbar()
        plt.title(title)

    plt.show()

if __name__ == '__main__':
    main()

calculate_numpy() 演示了高级索引的用法:它先在所有点上按 rr <= 1 的公式计算,再覆盖 rr > 1 的点,因此部分值被计算了两次。如果想完全避免重复计算,就需要像 calculate_loop() 那样预先创建缓冲区;不过就运行时间而言,这点重复计算是可以接受的。

假设程序保存在文件draw.py中。我们有numpy ndarray和循环版本的代码,并且可以使用timeit对它们进行基准测试:

$ python -m timeit -s "import draw" "draw.calculate_loop()"
10 loops, best of 3: 95.2 msec per loop
$ python -m timeit -s "import draw" "draw.calculate_numpy()"
100 loops, best of 3: 2.11 msec per loop

可以看到,numpy 版本比循环版本快约 45 倍(2.11 ms 对 95.2 ms)。在大多数情况下,这种写法已经足够好了。