python为什么数据类型被def函数改变了?

时间:2016-03-14 02:29:38

标签: python numpy

为什么num_r1(x)和num_r2(x)键入numpy.ndarray,但num_r(t)是float类型?如何将num_r(t)类型保留为数组?

def num_r(t):
    for x in t:
        if x>tx:
            return num_r2(x)
        else:
            return num_r1(x)

谢谢!

完整的例子如下

# -*- coding: utf-8 -*
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import math
from pylab import *

#### physical parameters
c = 2.998*10**10
hp = 6.626*10**-27
hb = 1.055*10**-27
kb = 1.381*10**-16
g = 6.673*10**-8
me = 9.109*10**-28
mp = 1.673*10**-24
q = 4.803*10**-10  #gausi
sigT = 6.652*10**-25

# The evolution of the characteristic frequencies
p = 2.5
E52 = 1
epsB_r = 1
epse_r = 1

D28 = 1 
n1 = 1.0
nu15 = 1*10**(-5)
r014 = 1
g42 = 1
delt12 =1
g4 = g42*10**2.5

E0 = E52*10**52
eta = g4
N0 = E0/(g4*mp*c**2)



p_tx = 3**(1./3)*2**(4./3)*mp**(-1./3)*c**(-5./3)
tx = p_tx*n1**(-1./3)*eta**(-8./3)

p_num_r1 = 2**(11./2)*7**(-2)*mp**(5./2)*me**(-3)*pi**(-1./2)*q*p_tx**(-6)*2**30*3**18*10**12

p_nuc_r1 = 2**(-33./2)*3**(-4)*10**(-4)*me*mp**(-3./2)*c**(-2)*sigT**(-2)*pi**(-1./2)*q

p_Fmax_r1 = 2**(15./2)*3**(9./2)*10**30*p_tx**(-3./2)*10**(-56)*me*mp**(1./2)*c**3*sigT*q**(-1)*2**(1./2)*3**(-1)

p_num_r2 = 2**(11./2)*7**(-2)*mp**(5./2)*me**(-3)*pi**(-1./2)*q*p_tx**(54./35)*(2**5*3**3*10**2)**(-54./35)

p_nuc_r2 = 2**(-13./2)*3**2*pi**(-1./2)*me*mp**(-3./2)*c**(-2)*sigT**(-2)*q*p_tx**(-74./35)*(2**5*3**3*10**2)**(4./35)

p_Fmax_r2 = 2**(1./2)*3**(-1)*pi**(-1./2)*me*mp**(1./2)*c**3*sigT*q**(-1)*10**(-56)



num_r1 = lambda t : p_num_r1*eta**18*((p-2)/(p-1))**2*epse_r**2*epsB_r**(1./2)*n1**(5./2)*t**6*E52**(-2)

nuc_r1 = lambda t : p_nuc_r1*eta**(-4)*epsB**(-3./2)*n1**(-3./2)*t**(-2)

Fmax_r1 = lambda t : p_Fmax_r1*N0**t**(3./2)*n1*eta**6*E52**(-1./2)*D28**(-2)*epsB_r**(1./2)

num_r2 = lambda t : p_num_r2*((p-2)/(p-1))**2*n1**(-74./35)*n1**(74./105)*eta**(592./105)*E52**(-74./105)

nuc_r2 = lambda t : p_nuc_r2*eta**(172./105)*t**(4./35)*n1**(-167./210)*epsB_r**(-3./2)

Fmax_r2 = lambda t : N0*eta**(62./105)*n1**(37./210)*epsB_r**(1./2)*t**(-34./35)*D28**(-2)

def fspe(t,u):
  if num_r(t)<nuc_r(t):
    return np.where(u<num_r(t),(u/num_r(t))**(1./3)*Fmax_r(t),np.where(u<nuc_r(t),(u/num_r(t))**(-(p-1.)/2)*Fmax_r(t),(u/nuc_r(t))**(-p/2)*(nuc_r(t)/num_r(t))**(-(p-1.)/2)*Fmax_r(t)))
  else:
    return np.where(u<nuc_r(t),(u/nuc_r(t))**(1./3)*Fmax_r(t),np.where(u<num_r(t),(u/nuc_r(t))**(-1./2)*Fmax_r(t),(u/num_r(t))**(-p/2)*(num_r(t)/nuc_r(t))**(-1.2)*Fmax_r(t)))


def num_r(t):
 for x in t:
  if x>tx:
    return num_r2(x)
  else:
    return num_r1(x)
def nuc_r(t):
 for x in t:
  if t>tx:
    return nuc_r2(x)
  else:
    return nuc_r1(x)
def Fmax_r(t):
 for x in t:
  if t>tx:
    return Fmax_r2(x)
  else:
    return Fmax_r1(x)


i= np.arange(-4,6,0.1)
t = 10**i
dnum   = [math.log10(mmm) for mmm in num_r(t)]
dnuc   = [math.log10(j) for j in nuc_r(t)]
nu_obs = [math.log(2.4*10**17,10) for a in i]
plt.figure('God Bless: Observable Limit')
plt.title(r'$\nu_{obs}$ and $\nu_c$ and $\nu_m$''\nComparation')
plt.xlabel('Time: log t')
plt.ylabel(r'log $\nu$')
plt.axvline(math.log10(tx))
plt.plot(i,nu_obs,'.',label=r'$\nu_{obs}$')
plt.plot(i,dnum,'D',label=r'$\nu_m$')
plt.plot(i,dnuc,'s',label=r'$\nu_c$')
plt.legend()
plt.grid(True)
plt.savefig("nu_obs.eps", dpi=120,bbox_inches='tight')
plt.show()

但是有一个错误

TypeError  Traceback (most recent call last)
<ipython-input-250-c008d4ed7571> in <module>()
     95     i= np.arange(-4,6,0.1)
     96     t = 10**i
--->     97     dnum   = [math.log10(mmm) for mmm in num_r(t)]
  

TypeError:'float'对象不可迭代

2 个答案:

答案 0 :(得分:1)

你应该把你的功能写成:

def num_r_(x):
  if x > tx:
      return num_r2(x)
  else:
      return num_r1(x)

然后将其传递至np.vectorize,将其从float提升至floatnp.arraynp.array

num_r = np.vectorize(num_r_)

来自Efficient evaluation of a function at every cell of a NumPy array

然后当你在:

中使用它时
dnum = [math.log10(mmm) for mmm in num_r(t)]

你应该这样做:

dnum = np.log10(num_r(t))

也就是说不要使用math模块中的函数。使用np模块中的那些,因为它们可以np.array以及浮动。

如:

i = np.arange(-4,6,0.1)
t = 10**i

导致t成为np.array

答案 1 :(得分:1)

所以i是一个数组(arange);所以t(数学表达式为i)。

def num_r(t):
 for x in t:
  if x>tx:
    return num_r2(x)
  else:
    return num_r1(x)

您在t上进行迭代。 xt的元素。您对其进行测试并通过num_r2num_r1传递,然后立即返回 。因此,只处理第一个元素t。因此错误 - num_r返回一个值,而不是数组。

您需要以处理num_r的所有值的方式编写t,而不仅仅是第一个。{1}}。一种简单粗暴的方式是

def num_r(t):
    result = []
    for x in t:
       if x>tx:
          value = num_r2(x)
       else:
          value = num_r1(x)
       result.append(value)
    # result = np.array(result)
    return result

现在num_r应该返回与t长度相同的列表,并且可以在列表推导中使用

[math.log10(mmm) for mmm in num_r(t)]

num_r可以写为列表理解:

[(num_r2(x) if x>tx else num_r1(x)) for x in t]

你可以让它返回一个数组而不是一个列表,但只要你在列表推导中使用它,就没有必要了。列表很好。

如果它确实返回了一个数组,那么你可以用numpy日志操作替换列表理解,例如。

np.log10(num_r(t))

如果num_r1num_r2被写入,那么他们可以拿一个数组(看起来像他们一样,但我还没有测试过),你可以写

def num_r(t):
     ind = t>tx
     result = np.zeros_like(t)
     result[ind] = num_r2(t[ind])
     result[~ind] = num_r1(t[~ind])
     return result

我们的想法是在t中找到>tx的值的掩码,并立即通过num_r2传递所有值;同样适用于num_r1;并在result的正确位置收集值。结果是一个可以传递给np.log10的数组。这应该比在t或使用np.vectorize上进行迭代要快得多。

我的建议中可能存在一些错误,因为我没有在脚本或解释器中测试它们。但是潜在的想法应该是正确的,并让你走上正确的道路。