如何根据numpy中的条件拆分异构数组?

时间:2019-04-07 07:05:27

标签: python arrays python-3.x numpy

尝试根据条件拆分numpy数组。过滤器必须使用split_column及其split_value并将数组分为两部分,其中一部分包含给定split_column上所有行<= split_value的子数组。

即鉴于

 a = np.array([[5, 'hi', 23],
               [4, 'we', 15],
               [3, 'me', 10],
               [2, 'be', 67],
               [1, 'it', 100]])

split_column = 0
split_value = 3

预期输出为

     [[3, 'me', 10],
      [2, 'be', 67],
      [1, 'it', 100]]

我尝试了此解决方案a[a[:, split_column] <= split_value],但仅当所有元素都是数字时,它才有效。

对于numpy数组中的混合类型(如上所示),我得到

TypeError:“ numpy.ndarray”和“ int”的实例之间不支持“ <=”

a[a[:, split_column] <= str(split_value)]中使用str()并不是解决方案,因为10 <= 3变为true,这是不正确的。对于column(1),我需要str比较,但对于其他列,它应该是数值比较。

我们该如何在numpy中执行此操作,或者在比较之前必须遍历所有元素检查类型?

1 个答案:

答案 0 :(得分:1)

使用type转换为所需的numpy.array.astype

a[a[:,0].astype(int) <= 3]
array([['3', 'me', '10'],
       ['2', 'be', '67'],
       ['1', 'it', '100']], dtype='<U11')