numpy.apply_along_axis会截断字符串,因为它会推断出错误的dtype'<u1'

时间:2019-05-04 17:14:31

标签: python numpy numpy-ndarray

=“”

我不知道如何返回dtype U3

我要:

  1. 将my_array应用于轴

  2. 对于每一行,返回一个字符串

def my_function(x):
    return x[2]
my_array = np.array([[1,1,"A"],[1,1,"BBB"], [1,1,"CCC"]])
np.apply_along_axis(my_function, axis=1, arr=my_array)

我期望输出: array(['A', 'BBB', 'CCC'], dtype='<U3') 但实际输出是 array(['A', 'B', 'C'], dtype='<U1')

因为第一个元素('A')具有固定大小的U1,每个下一个元素都被截断为U1('BBB'->'B')。

您知道如何使用dtype U3将代码更改为字符串吗?

2 个答案:

答案 0 :(得分:0)

尝试一下(尽管可能应该有更好的方法):

import numpy as np

def my_function(x):
    return np.array(x[2], dtype='<U3')

my_array = np.array([[1,1,"A"],[1,1,"BBB"], [1,1,"CCC"]])
np.apply_along_axis(my_function, axis=1, arr=my_array)

答案 1 :(得分:0)

对于此特定用例,您可以使用切片,即

my_array[:, 2]

,并完全避免应用apply_along_axis。但是我同意从函数的第一个应用程序推断类型是很麻烦的。还有一个issue

顺便说一句:数组中的数字将转换为字符串,但这会导致<U21的类型不尽人意。如果直接将它们设置为字符串,则会得到<U3