我正在尝试编写一个接受浮点数或浮点数数组的函数,并使用相同的代码行来处理它们。例如,如果是浮点数,我想返回浮点数本身,如果是数组,则想返回浮点数数组的总和。像这样
def func(a):
return np.sum(a)
,并且func(1.2)
都返回1.2,而func(np.array([1.2,1.3,1.4])
都返回3.9。
答案 0 :(得分:1)
这已经起作用了,哪里出了问题?
import numpy as np
def func(a):
return np.sum(a)
print(func(np.array([1.2,2.3,3.2])))
print(func(1.2))
输出:
6.7
1.2
答案 1 :(得分:1)
您可以使用参数变平:
def func(*args):
# code to handle args
return sum(args)
现在以下内容具有相同的行为:
>>> func(3)
3
>>> func(3, 4, 5)
12
>>> func(*[3, 4, 5])
12
答案 2 :(得分:1)
确保输入为NumPy数组的通常方法是使用np.asarray()
:
import numpy as np
def func(a):
a = np.asarray(a)
return np.sum(a)
func(1.2)
# 1.2
func([1.2, 3.4])
# 4.6
func(np.array([1.2, 3.4]))
# 4.6
或者,如果要获取数组的len()
,请确保其至少为一维,请使用np.atleast_1d()
:
def func(a):
a = np.atleast_1d(a)
return a.shape[0]
func(1.2)
# 1
func([1.2, 3.4])
# 2
func(np.array([1.2, 3.4]))
# 2
答案 3 :(得分:0)
您可以检查输入是否为浮点数,然后在处理总和之前将其放在列表中:
def func(a):
if isinstance(a, float):
a = [a]
return np.sum(a)