我正在学习 Python,但对面向对象编程(OOP)还是新手。我希望函数 fA、fB、fC 能一次性作用于整个 r-theta 网格(即直接对整个数组运算),而不是每次只计算一个点。问题出在条件判断 (r <= 1) 上:`if` 只能判断单个数值,无法直接用于数组。下面的代码非常难看(!),但是能正常运行。
我怎样才能让它更符合 Python 的风格(更 Pythonic)?谢谢!
(在这个简化的例子中,请注意 r > 1 分支的公式里含有除以 r 的运算,当 r 趋于零时会发散,所以不能对所有点都套用这个分支。)
from math import pi, sin, cos, exp
import numpy as np
import matplotlib.pyplot as plt
def fA(rr, th, a, b, c):
    """Evaluate field A at radius ``rr`` and angle ``th``.

    Accepts plain scalars (as the original point-by-point version did) or
    NumPy arrays, which are combined element-wise with broadcasting.

    For rr <= 1:  (a*sin(th), b*rr*cos(th), c*rr)
    For rr >  1:  (a/rr*sin(th), b/rr*cos(th), c/rr)

    Returns a tuple (fx, fy, fz): NumPy float64 scalars (a subclass of
    float) for scalar input, arrays of the broadcast shape for array input.
    """
    rr = np.asarray(rr, dtype=float)
    th = np.asarray(th, dtype=float)
    inside = rr <= 1
    # np.where evaluates both branches everywhere, so divide by 1 wherever
    # the inner formula is selected; otherwise rr == 0 would warn / give inf.
    rr_div = np.where(inside, 1.0, rr)
    sin_th = np.sin(th)
    cos_th = np.cos(th)
    fx = np.where(inside, a * sin_th, (a / rr_div) * sin_th)
    fy = np.where(inside, b * rr * cos_th, (b / rr_div) * cos_th)
    fz = np.where(inside, c * rr, c / rr_div)
    # [()] unwraps 0-d results to scalars and leaves n-d arrays unchanged.
    return (fx[()], fy[()], fz[()])
def fB(rr, th, a, b, c):
    """Evaluate field B at radius ``rr`` and angle ``th``.

    Accepts plain scalars (as the original point-by-point version did) or
    NumPy arrays, which are combined element-wise with broadcasting.

    For rr <= 1:  (b*sin(2*th), a*rr*cos(2*th), c*rr)
    For rr >  1:  (b/rr*sin(2*th), a/rr*cos(2*th), c)

    Returns a tuple (fx, fy, fz): NumPy float64 scalars (a subclass of
    float) for scalar input, arrays for array input.
    """
    rr = np.asarray(rr, dtype=float)
    th = np.asarray(th, dtype=float)
    inside = rr <= 1
    # np.where evaluates both branches everywhere, so divide by 1 wherever
    # the inner formula is selected; otherwise rr == 0 would warn / give inf.
    rr_div = np.where(inside, 1.0, rr)
    sin_2th = np.sin(2. * th)
    cos_2th = np.cos(2. * th)
    fx = np.where(inside, b * sin_2th, (b / rr_div) * sin_2th)
    fy = np.where(inside, a * rr * cos_2th, (a / rr_div) * cos_2th)
    fz = np.where(inside, c * rr, c)
    # [()] unwraps 0-d results to scalars and leaves n-d arrays unchanged.
    return (fx[()], fy[()], fz[()])
def fC(rr, th, a, b, c):
    """Evaluate field C at radius ``rr`` and angle ``th``.

    Accepts plain scalars (as the original point-by-point version did) or
    NumPy arrays, which are combined element-wise with broadcasting.
    ``a`` and ``b`` are accepted for a uniform signature but not used.

    For rr <= 1:  (exp(rr-1)*cos(th), exp(rr-1)*sin(th), c)
    For rr >  1:  (exp(1-rr)*cos(th), exp(1-rr)*sin(th), c/rr)

    Returns a tuple (fx, fy, fz): NumPy float64 scalars (a subclass of
    float) for scalar input, arrays for array input.
    """
    rr = np.asarray(rr, dtype=float)
    th = np.asarray(th, dtype=float)
    inside = rr <= 1
    # np.where evaluates both branches everywhere, so divide by 1 wherever
    # the inner value is selected; otherwise rr == 0 would warn / give inf.
    rr_div = np.where(inside, 1.0, rr)
    # Exponent is rr-1 inside and 1-rr outside; both are <= 0.
    decay = np.exp(np.where(inside, rr - 1., 1. - rr))
    fx = decay * np.cos(th)
    fy = decay * np.sin(th)
    fz = np.where(inside, c, c / rr_div)
    # [()] unwraps 0-d results to scalars and leaves n-d arrays unchanged.
    return (fx[()], fy[()], fz[()])
# Grid size and spacing: a square spanning [-2, 2] in both x and y.
nx = 101
ny = 101
dx = 4. / (nx-1)
dy = 4. / (ny-1)
# Build the coordinate grids in one shot instead of filling them point by
# point. The arithmetic dx*(i - (n-1)/2) is the same as before, just done on
# whole index vectors; meshgrid's default 'xy' indexing gives
# X[iy, ix] = xs[ix] and Y[iy, ix] = ys[iy], matching the old layout.
xs = dx * (np.arange(nx) - (nx-1)/2)
ys = dy * (np.arange(ny) - (ny-1)/2)
X, Y = np.meshgrid(xs, ys)
r = np.sqrt(X**2. + Y**2.)
theta = np.arctan2(Y, X)

# np.vectorize applies each function element-wise and, because each one
# returns a 3-tuple, hands back a tuple of three arrays. otypes pins every
# output to float instead of guessing from the first evaluated point. This
# works whether fA/fB/fC accept only scalars or whole arrays. Note that
# vectorize still loops in Python: it removes the bookkeeping, not the cost.
Ax, Ay, Az = np.vectorize(fA, otypes=[float] * 3)(r, theta, 1.0, 1.0, 1.5)
Bx, By, Bz = np.vectorize(fB, otypes=[float] * 3)(r, theta, 1.5, 0.8, 1.0)
Cx, Cy, Cz = np.vectorize(fC, otypes=[float] * 3)(r, theta, 0.9, 1.1, 1.2)
# Show the nine field components in a 3x3 grid, one colour-mapped panel each,
# in row order A, B, C and column order x, y, z.
panels = [('Ax', Ax), ('Ay', Ay), ('Az', Az),
          ('Bx', Bx), ('By', By), ('Bz', Bz),
          ('Cx', Cx), ('Cy', Cy), ('Cz', Cz)]
plt.figure()
for position, (label, data) in enumerate(panels, start=1):
    plt.subplot(3, 3, position)
    plt.imshow(data)
    plt.colorbar()
    plt.title(label)
plt.show()
答案 0 :(得分:1)
以你的其中一个关系式为例,可以用 numpy.where 代替 if ... 语句:
# np.where picks element by element between two fully computed candidate
# arrays; cos(th) is common to both branches, so it is applied once.
fx = where(rr <= 1, exp(rr - 1.), exp(1. - rr)) * cos(th)
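这实际上就是对整个数组逐元素执行的 if/else,而不是一次只处理一个数字。这样一来,就可以不用循环,直接写 Ax, Ay, Az = fA(...)。前提是函数内部也要改用 numpy 的 exp、sin、cos 等函数,而不是 math 模块中只接受标量的版本。
注意,np.where 会先在整个数组上把两个分支都计算出来,再逐元素挑选结果。因此对于含有除以 r 的分支(例如 fA 中 r > 1 的公式),在 r = 0 处仍会执行除法并产生警告,只是这些值最终不会被选中。
您可以用 np.meshgrid 或 np.mgrid 生成 X 和 Y,而不必逐点赋值。
如果不想对所有元素都计算两个分支,可以使用布尔掩码索引: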
# Evaluate each branch only where it applies: boolean-mask indexing selects
# the matching elements, so the r > 1 formula never sees r <= 1 points.
# Using ~inside (rather than a second test rr > 1) also fills NaN radii,
# which would otherwise be left as uninitialised empty_like memory.
inside = rr <= 1
fx = empty_like(rr)
fx[inside] = exp(rr[inside] - 1.) * cos(th[inside])
fx[~inside] = exp(1. - rr[~inside]) * cos(th[~inside])
答案 1 :(得分:1)
正如 @mdurant 所说,np.where 和 np.meshgrid 会很有用。下面我重新组织了您的代码,并用 numpy 的高级索引(布尔掩码索引)提供另一种避免 Python 循环的方法:
import sys
from math import pi, sin, cos, exp
import numpy as np
import matplotlib.pyplot as plt
def _generate_coordinate(nx, ny):
"""
Generate coordinate data points in a function to prevent namespace
pollution.
"""
dx = 4. / (nx-1)
dy = 4. / (ny-1)
X = np.zeros((ny,nx))
Y = np.zeros((ny,nx))
for ix in range(nx):
for iy in range(ny):
X[iy,ix] = dx*(ix - (nx-1)/2)
Y[iy,ix] = dy*(iy - (ny-1)/2)
return np.sqrt(X**2 + Y**2), np.arctan2(Y,X)
# Grid resolution. Pass nx/ny through so that editing them here actually
# changes the grid the calculators use (they were previously hard-coded).
nx = ny = 101
r, theta = _generate_coordinate(nx, ny)
def calculate_numpy(r_grid=None, theta_grid=None):
    """
    Evaluate fields A, B and C over a whole polar grid with NumPy arrays.

    Each helper first applies the r <= 1 formulas to every point (they contain
    no division by r, so they are safe at r == 0), then overwrites the points
    with r > 1 via boolean-mask (advanced) indexing. The r > 1 formulas, which
    divide by r, therefore only ever see r > 1.

    r_grid, theta_grid: radius and angle arrays of the same shape (the mask
        built from r_grid is used to index theta_grid). They default to the
        module-level ``r`` and ``theta``.
    Returns (Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz) as new arrays; the inputs
    are not modified.
    """
    if r_grid is None:
        r_grid = r
    if theta_grid is None:
        theta_grid = theta

    # Helper methods for vector-based calculator.
    def fA(rr, th, a, b, c):
        # r <= 1 formulas, evaluated everywhere (no division by r).
        arrx = a * np.sin(th)
        arry = b * rr * np.cos(th)
        arrz = c * rr
        # Overwrite the r > 1 points; rr/th are narrowed to just those points.
        slct = rr > 1
        rr = rr[slct]
        th = th[slct]
        arrx[slct] = a / rr * np.sin(th)
        arry[slct] = b / rr * np.cos(th)
        arrz[slct] = c / rr
        return arrx, arry, arrz

    def fB(rr, th, a, b, c):
        # r <= 1 formulas, evaluated everywhere (no division by r).
        arrx = b * np.sin(2. * th)
        arry = a * rr * np.cos(2. * th)
        arrz = c * rr
        # Overwrite the r > 1 points; rr/th are narrowed to just those points.
        slct = rr > 1
        rr = rr[slct]
        th = th[slct]
        arrx[slct] = b / rr * np.sin(2. * th)
        arry[slct] = a / rr * np.cos(2. * th)
        arrz[slct] = c
        return arrx, arry, arrz

    def fC(rr, th, a, b, c):
        # r <= 1 formulas, evaluated everywhere (no division by r).
        arrx = np.exp(rr - 1) * np.cos(th)
        arry = np.exp(rr - 1) * np.sin(th)
        # Constant c inside; dtype=float so a non-float rr cannot truncate c.
        arrz = np.full_like(rr, c, dtype=float)
        # Overwrite the r > 1 points; rr/th are narrowed to just those points.
        slct = rr > 1
        rr = rr[slct]
        th = th[slct]
        arrx[slct] = np.exp(1. - rr) * np.cos(th)
        arry[slct] = np.exp(1. - rr) * np.sin(th)
        arrz[slct] = c / rr
        return arrx, arry, arrz

    # Carry out calculation.
    Ax, Ay, Az = fA(r_grid, theta_grid, 1.0, 1.0, 1.5)
    Bx, By, Bz = fB(r_grid, theta_grid, 1.5, 0.8, 1.0)
    Cx, Cy, Cz = fC(r_grid, theta_grid, 0.9, 1.1, 1.2)
    return Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz
def calculate_loop(r_grid=None, theta_grid=None):
    """
    Evaluate fields A, B and C point by point with plain Python loops.

    This is the slow scalar counterpart of calculate_numpy: the helpers use
    math.sin/cos/exp and handle exactly one grid point per call.

    r_grid, theta_grid: radius and angle arrays of the same shape; they
        default to the module-level ``r`` and ``theta``. The output shape is
        taken from r_grid instead of the module-level nx/ny.
    Returns (Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz), each shaped like r_grid.
    """
    if r_grid is None:
        r_grid = r
    if theta_grid is None:
        theta_grid = theta

    # Helper methods for loop calculator (scalar inputs only).
    def fA(rr, th, a, b, c):
        if rr <= 1:
            fx = a * sin(th)
            fy = b * rr * cos(th)
            fz = c * rr
        else:
            fx = (a / rr) * sin(th)
            fy = (b / rr) * cos(th)
            fz = (c / rr)
        return (fx, fy, fz)

    def fB(rr, th, a, b, c):
        if rr <= 1:
            fx = b * sin(2. * th)
            fy = a * rr * cos(2. * th)
            fz = c * rr
        else:
            fx = (b / rr) * sin(2. * th)
            fy = (a / rr) * cos(2. * th)
            fz = c
        return (fx, fy, fz)

    def fC(rr, th, a, b, c):
        if rr <= 1:
            fx = exp(rr - 1.) * cos(th)
            fy = exp(rr - 1.) * sin(th)
            fz = c
        else:
            fx = exp(1. - rr) * cos(th)
            fy = exp(1. - rr) * sin(th)
            fz = c / rr
        return (fx, fy, fz)

    # Create one float buffer per field component, filled in place below.
    shape = np.shape(r_grid)
    Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz = [np.zeros(shape) for _ in range(9)]
    # Carry out calculation with Python loops. This is slow (deliberately).
    for idx in np.ndindex(*shape):
        rr = r_grid[idx]
        th = theta_grid[idx]
        Ax[idx], Ay[idx], Az[idx] = fA(rr, th, 1.0, 1.0, 1.5)
        Bx[idx], By[idx], Bz[idx] = fB(rr, th, 1.5, 0.8, 1.0)
        Cx[idx], Cy[idx], Cz[idx] = fC(rr, th, 0.9, 1.1, 1.2)
    return Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz
def main():
    """
    Compute all nine field components and show them in a 3x3 grid of plots.

    The first command-line argument selects the implementation: ``numpy``
    (used when no argument is given) or ``loop``. Any other value exits with
    a usage message.
    """
    # Explicit dispatch table instead of a globals() lookup, so only the two
    # calculators can be selected from the command line.
    calculators = {
        'numpy': calculate_numpy,
        'loop': calculate_loop,
    }
    choice = sys.argv[1] if len(sys.argv) > 1 else 'numpy'
    if choice not in calculators:
        sys.exit('usage: {} [{}]'.format(sys.argv[0], '|'.join(sorted(calculators))))
    results = calculators[choice]()

    # One colour-mapped panel per component, in the order the calculator
    # returns them (A, B, C rows; x, y, z columns).
    names = ('Ax', 'Ay', 'Az', 'Bx', 'By', 'Bz', 'Cx', 'Cy', 'Cz')
    plt.figure()
    for position, (name, data) in enumerate(zip(names, results), start=1):
        plt.subplot(3, 3, position)
        plt.imshow(data)
        plt.colorbar()
        plt.title(name)
    plt.show()
if __name__ == '__main__':
    # Entry point only when executed as a script; importing the module (as
    # the timeit benchmark does) skips main() and its plotting.
    main()
函数 calculate_numpy() 演示了高级索引的用法:它先按 r <= 1 的公式计算所有点,再用布尔掩码覆盖 r > 1 的点,因此有一部分计算是重复的。如果想完全避免重复计算,需要像 calculate_loop() 那样预先创建缓冲数组。不过就运行时间而言,我认为这点重复计算是可以接受的。
假设程序保存在文件 draw.py 中。代码同时提供了 numpy ndarray 版本和循环版本,可以用 timeit 对它们进行基准测试:
$ python -m timeit -s "import draw" "draw.calculate_loop()"
10 loops, best of 3: 95.2 msec per loop
$ python -m timeit -s "import draw" "draw.calculate_numpy()"
100 loops, best of 3: 2.11 msec per loop
可以看到,numpy 版本比循环版本快约 45 倍(每次 2.11 ms 对 95.2 ms)。在大多数情况下,这已经足够好了。