清除返回单元素Numpy数组的方法

时间:2017-10-25 07:14:53

标签: python numpy

是否有一种干净的方法来编写返回单元素numpy数组作为元素本身的函数?

我想说我想要一个简单的方形函数,我希望我的返回值与输入的dtype相同。我可以这样写:

def foo(x):
    result = np.square(x)
    if len(result) == 1:
        return result[0]
    return result

def foo(x):
    if len(x) == 1:
        return x**2
    return np.square(x)

有更简单的方法吗?这样我可以将这个函数用于标量和数组吗?

我知道我可以直接检查输入的dtype并使用IF语句使其工作,但是有更简洁的方法吗?

2 个答案:

答案 0 :(得分:2)

我不确定我是否完全理解这个问题,但是这样的事情会有所帮助吗?

def square(x):
    if 'numpy' in str(type(x)):
        return np.square(x)
    else:
        if isinstance(x, list):
            return list(np.square(x))
        if isinstance(x, int):
            return int(np.square(x))
        if isinstance(x, float):
            return float(np.square(x))

我定义了一些测试用例:

np_array_one = np.array([3.4])
np_array_mult = np.array([3.4, 2, 6])
int_ = 5
list_int = [2, 4, 2.9]
float_ = float(5.3)
list_float = [float(4.5), float(9.1), float(7.5)]

examples = [np_array_one, np_array_mult, int_, list_int, float_, list_float]

所以我们可以看到函数的行为方式。

for case in examples:
    print 'Input type: {}.'.format(type(case))
    out = square(case)
    print out
    print 'Output type: {}'.format(type(out))
    print '-----------------'

输出:

Input type: <type 'numpy.ndarray'>.
[ 11.56]
Output type: <type 'numpy.ndarray'>
-----------------
Input type: <type 'numpy.ndarray'>.
[ 11.56   4.    36.  ]
Output type: <type 'numpy.ndarray'>
-----------------
Input type: <type 'int'>.
25
Output type: <type 'int'>
-----------------
Input type: <type 'list'>.
[4.0, 16.0, 8.4100000000000001]
Output type: <type 'list'>
-----------------
Input type: <type 'float'>.
28.09
Output type: <type 'float'>
-----------------
Input type: <type 'list'>.
[20.25, 82.809999999999988, 56.25]
Output type: <type 'list'>
-----------------

从测试用例中,输入和输出始终相同。但是,功能并不是很干净。

我在SO中使用了这个question的一些代码。

答案 1 :(得分:1)

我认为你需要一个很好的理由想要那个。 (你能解释一下为什么需要这个吗?)

此函数的所有客户端都必须检查结果是数组还是单个元素,或者必须将其转换为数组。 通常,如果迭代数组的所有元素,即使它只是一个元素,也可以获得非常优雅的代码。

除非它总是必须是一个单独的元素(这是一个转换函数),但是return语句应该在空/长数组上抛出异常。

除此之外,您拥有的代码完全可以理解/可读。任何聪明的技巧来改善&#39;每当你或同事必须阅读它时,这将是一种精神负担。

- 编辑

我明白你的观点。可能你已经遇到了不允许len(1)的问题(int / float没有len()),所以你可以对输入参数进行类型检查。例如。

if (type(x) == list) ...