尝试根据条件拆分numpy数组。过滤器必须使用split_column及其split_value并将数组分为两部分,其中一部分包含给定split_column上所有行<= split_value的子数组。
即鉴于
a = np.array([[5, 'hi', 23],
[4, 'we', 15],
[3, 'me', 10],
[2, 'be', 67],
[1, 'it', 100]])
split_column = 0
split_value = 3
预期输出为
[[3, 'me', 10],
[2, 'be', 67],
[1, 'it', 100]]
我尝试了此解决方案a[a[:, split_column] <= split_value]
,但仅当所有元素都是数字时,它才有效。
对于numpy数组中的混合类型(如上所示),我得到
TypeError:“ numpy.ndarray”和“ int”的实例之间不支持“ <=”
在a[a[:, split_column] <= str(split_value)]
中使用str()并不是解决方案,因为10 <= 3变为true,这是不正确的。对于column(1),我需要str比较,但对于其他列,它应该是数值比较。
我们该如何在numpy中执行此操作,或者在比较之前必须遍历所有元素检查类型?
答案 0 :(得分:1)
使用type
将列转换为所需的numpy.array.astype
:
a[a[:,0].astype(int) <= 3]
array([['3', 'me', '10'],
['2', 'be', '67'],
['1', 'it', '100']], dtype='<U11')