为什么num_r1(x)和num_r2(x)键入numpy.ndarray,但num_r(t)是float类型?如何将num_r(t)类型保留为数组?
def num_r(t):
for x in t:
if x>tx:
return num_r2(x)
else:
return num_r1(x)
谢谢!
完整的例子如下
# -*- coding: utf-8 -*
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import math
from pylab import *
#### physical parameters
c = 2.998*10**10
hp = 6.626*10**-27
hb = 1.055*10**-27
kb = 1.381*10**-16
g = 6.673*10**-8
me = 9.109*10**-28
mp = 1.673*10**-24
q = 4.803*10**-10 #gausi
sigT = 6.652*10**-25
# The evolution of the characteristic frequencies
p = 2.5
E52 = 1
epsB_r = 1
epse_r = 1
D28 = 1
n1 = 1.0
nu15 = 1*10**(-5)
r014 = 1
g42 = 1
delt12 =1
g4 = g42*10**2.5
E0 = E52*10**52
eta = g4
N0 = E0/(g4*mp*c**2)
p_tx = 3**(1./3)*2**(4./3)*mp**(-1./3)*c**(-5./3)
tx = p_tx*n1**(-1./3)*eta**(-8./3)
p_num_r1 = 2**(11./2)*7**(-2)*mp**(5./2)*me**(-3)*pi**(-1./2)*q*p_tx**(-6)*2**30*3**18*10**12
p_nuc_r1 = 2**(-33./2)*3**(-4)*10**(-4)*me*mp**(-3./2)*c**(-2)*sigT**(-2)*pi**(-1./2)*q
p_Fmax_r1 = 2**(15./2)*3**(9./2)*10**30*p_tx**(-3./2)*10**(-56)*me*mp**(1./2)*c**3*sigT*q**(-1)*2**(1./2)*3**(-1)
p_num_r2 = 2**(11./2)*7**(-2)*mp**(5./2)*me**(-3)*pi**(-1./2)*q*p_tx**(54./35)*(2**5*3**3*10**2)**(-54./35)
p_nuc_r2 = 2**(-13./2)*3**2*pi**(-1./2)*me*mp**(-3./2)*c**(-2)*sigT**(-2)*q*p_tx**(-74./35)*(2**5*3**3*10**2)**(4./35)
p_Fmax_r2 = 2**(1./2)*3**(-1)*pi**(-1./2)*me*mp**(1./2)*c**3*sigT*q**(-1)*10**(-56)
num_r1 = lambda t : p_num_r1*eta**18*((p-2)/(p-1))**2*epse_r**2*epsB_r**(1./2)*n1**(5./2)*t**6*E52**(-2)
nuc_r1 = lambda t : p_nuc_r1*eta**(-4)*epsB**(-3./2)*n1**(-3./2)*t**(-2)
Fmax_r1 = lambda t : p_Fmax_r1*N0**t**(3./2)*n1*eta**6*E52**(-1./2)*D28**(-2)*epsB_r**(1./2)
num_r2 = lambda t : p_num_r2*((p-2)/(p-1))**2*n1**(-74./35)*n1**(74./105)*eta**(592./105)*E52**(-74./105)
nuc_r2 = lambda t : p_nuc_r2*eta**(172./105)*t**(4./35)*n1**(-167./210)*epsB_r**(-3./2)
Fmax_r2 = lambda t : N0*eta**(62./105)*n1**(37./210)*epsB_r**(1./2)*t**(-34./35)*D28**(-2)
def fspe(t,u):
if num_r(t)<nuc_r(t):
return np.where(u<num_r(t),(u/num_r(t))**(1./3)*Fmax_r(t),np.where(u<nuc_r(t),(u/num_r(t))**(-(p-1.)/2)*Fmax_r(t),(u/nuc_r(t))**(-p/2)*(nuc_r(t)/num_r(t))**(-(p-1.)/2)*Fmax_r(t)))
else:
return np.where(u<nuc_r(t),(u/nuc_r(t))**(1./3)*Fmax_r(t),np.where(u<num_r(t),(u/nuc_r(t))**(-1./2)*Fmax_r(t),(u/num_r(t))**(-p/2)*(num_r(t)/nuc_r(t))**(-1.2)*Fmax_r(t)))
def num_r(t):
for x in t:
if x>tx:
return num_r2(x)
else:
return num_r1(x)
def nuc_r(t):
for x in t:
if t>tx:
return nuc_r2(x)
else:
return nuc_r1(x)
def Fmax_r(t):
for x in t:
if t>tx:
return Fmax_r2(x)
else:
return Fmax_r1(x)
i= np.arange(-4,6,0.1)
t = 10**i
dnum = [math.log10(mmm) for mmm in num_r(t)]
dnuc = [math.log10(j) for j in nuc_r(t)]
nu_obs = [math.log(2.4*10**17,10) for a in i]
plt.figure('God Bless: Observable Limit')
plt.title(r'$\nu_{obs}$ and $\nu_c$ and $\nu_m$''\nComparation')
plt.xlabel('Time: log t')
plt.ylabel(r'log $\nu$')
plt.axvline(math.log10(tx))
plt.plot(i,nu_obs,'.',label=r'$\nu_{obs}$')
plt.plot(i,dnum,'D',label=r'$\nu_m$')
plt.plot(i,dnuc,'s',label=r'$\nu_c$')
plt.legend()
plt.grid(True)
plt.savefig("nu_obs.eps", dpi=120,bbox_inches='tight')
plt.show()
但是有一个错误
TypeError Traceback (most recent call last)
<ipython-input-250-c008d4ed7571> in <module>()
95 i= np.arange(-4,6,0.1)
96 t = 10**i
---> 97 dnum = [math.log10(mmm) for mmm in num_r(t)]
TypeError:'float'对象不可迭代
答案 0 :(得分:1)
你应该把你的功能写成:
def num_r_(x):
if x > tx:
return num_r2(x)
else:
return num_r1(x)
然后将其传递至np.vectorize
,将其从float
提升至float
至np.array
至np.array
num_r = np.vectorize(num_r_)
来自Efficient evaluation of a function at every cell of a NumPy array
然后当你在:
中使用它时dnum = [math.log10(mmm) for mmm in num_r(t)]
你应该这样做:
dnum = np.log10(num_r(t))
也就是说不要使用math
模块中的函数。使用np
模块中的那些,因为它们可以np.array
以及浮动。
如:
i = np.arange(-4,6,0.1)
t = 10**i
导致t
成为np.array
答案 1 :(得分:1)
所以i
是一个数组(arange
);所以t
(数学表达式为i
)。
def num_r(t):
for x in t:
if x>tx:
return num_r2(x)
else:
return num_r1(x)
您在t
上进行迭代。 x
是t
的元素。您对其进行测试并通过num_r2
或num_r1
传递,然后立即返回 。因此,只处理第一个元素t
。因此错误 - num_r
返回一个值,而不是数组。
您需要以处理num_r
的所有值的方式编写t
,而不仅仅是第一个。{1}}。一种简单粗暴的方式是
def num_r(t):
result = []
for x in t:
if x>tx:
value = num_r2(x)
else:
value = num_r1(x)
result.append(value)
# result = np.array(result)
return result
现在num_r
应该返回与t
长度相同的列表,并且可以在列表推导中使用
[math.log10(mmm) for mmm in num_r(t)]
num_r
可以写为列表理解:
[(num_r2(x) if x>tx else num_r1(x)) for x in t]
你可以让它返回一个数组而不是一个列表,但只要你在列表推导中使用它,就没有必要了。列表很好。
如果它确实返回了一个数组,那么你可以用numpy日志操作替换列表理解,例如。
np.log10(num_r(t))
如果num_r1
和num_r2
被写入,那么他们可以拿一个数组(看起来像他们一样,但我还没有测试过),你可以写
def num_r(t):
ind = t>tx
result = np.zeros_like(t)
result[ind] = num_r2(t[ind])
result[~ind] = num_r1(t[~ind])
return result
我们的想法是在t
中找到>tx
的值的掩码,并立即通过num_r2
传递所有值;同样适用于num_r1
;并在result
的正确位置收集值。结果是一个可以传递给np.log10
的数组。这应该比在t
或使用np.vectorize
上进行迭代要快得多。
我的建议中可能存在一些错误,因为我没有在脚本或解释器中测试它们。但是潜在的想法应该是正确的,并让你走上正确的道路。