Numpy apply_along_axis错误的dtype感染

时间:2017-09-07 12:46:15

标签: python numpy

使用NumPy时遇到以下问题:

代码:

import numpy as np
get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG'
arr = np.array([[1, 2], [30, 40]])
print np.apply_along_axis(get_label, 1, arr)
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label, 1, arr)

输出:

['SMALL' 'BIG']
['BIG' 'SMA'] # String 'SMALL' is stripped!

我可以看到NumPy以某种方式从函数返回的第一个值推断数据类型。我提出了以下解决方法 - 从函数返回NumPy数组,使用明确声明的dtype而不是string,并重新整形结果:

def get_label_2(x):
    if x.sum() <= 10:
        return np.array(['SMALL'], dtype='|S5')
    else:
        return np.array(['BIG'], dtype='|S5')
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label_2, 1, arr).reshape(arr.shape[0])

你知道更优雅的解决方案吗?

2 个答案:

答案 0 :(得分:1)

您可以使用Columns("M:M").Replace What:="\=", Replacement:="=", LookAt:=xlPart

np.where

在一个功能中:

arr1 = np.array([[1, 2], [30, 40]])
arr2 = np.array([[30, 40], [1, 2]])

print(np.where(arr1.sum(axis=1)<=10,'SMALL','BIG'))
print(np.where(arr2.sum(axis=1)<=10,'SMALL','BIG'))
['SMALL' 'BIG']
['BIG' 'SMALL']

答案 1 :(得分:0)

apply_along_axis不是一个优雅的解决方案;它方便,但不快。基本上它确实

In [277]: get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG'
In [279]: np.array([get_label(row) for row in np.array([[30,40],[1,2]])])
Out[279]: 
array(['BIG', 'SMALL'],
      dtype='<U5')
In [280]: res = np.zeros((2,),dtype='S5')
In [281]: arr = np.array([[30,40],[1,2]])
In [282]: for i in range(2):
     ...:     res[i] = get_label(arr[i,:])
     ...:     
In [283]: res
Out[283]: 
array([b'BIG', b'SMALL'],
      dtype='|S5')

除了概括形状并推导出res dtype。

对行进行简单的“迭代”操作&#39;像这样的情况你也可以这样做:

In [278]: np.array([get_label(row) for row in np.array([[1,2],[30,40]])])
Out[278]: 
array(['SMALL', 'BIG'],
      dtype='<U5')
In [279]: np.array([get_label(row) for row in np.array([[30,40],[1,2]])])
Out[279]: 
array(['BIG', 'SMALL'],
      dtype='<U5')

优雅的解决方案是避免Python级别循环,显式或隐藏,使用相反编译的数组方法,例如给sum一个轴:

In [284]: arr.sum(axis=1)
Out[284]: array([70,  3])