在numba中的numpy.array中删除一行

时间:2018-12-03 22:06:45

标签: python numpy numba

这是我第一次在这里发布内容。我正在尝试删除numba jitclass内numpy数组内的行。我编写了以下代码以删除任何包含3的行:

>>> a = np.array([[1,2,3,4],[5,6,7,8]])

>>> a

>>> array([[1, 2, 3, 4],
       [5, 6, 7, 8]])

>>> i = np.where(a==3)

>>> i

>>> (array([0]), array([2]))

我无法使用numpy.delete()函数,因为numba不支持该函数,并且无法将None类型的值分配给该行。我所能做的就是通过以下方式将0分配给该行:

>>> a[i[0]] = 0

>>> a

>>> array([[0, 0, 0, 0],
       [5, 6, 7, 8]])

但是我想完全删除该行。

任何帮助将不胜感激。

非常感谢您。

3 个答案:

答案 0 :(得分:1)

这实际上不是一件容易的事,因为numba具有以下限制:

  • 不支持np.delete
  • 不支持axisnp.all中的np.any关键字
  • 不支持2D数组索引编制(至少不支持布尔掩码)
  • 没有或妨碍使用np.zeros(shape, dtype=np.bool)或类似功能直接创建防毒面具

但是您仍然可以采用几种方法来解决问题。我测试了一些,创建一个布尔型掩码似乎是最快,最干净的方法。

@nb.njit
def delete_workaround(arr, num):
    mask = np.zeros(arr.shape[0], dtype=np.int64) == 0
    mask[np.where(arr == num)[0]] = False
    return arr[mask]

a = np.array([[1,2,3,4],[5,6,7,8]])

delete_workaround(a, 3)

即使仅返回一行或一个空数组,该解决方案也具有保留数组尺寸的巨大优势。这对于jitclass非常重要,因为jitclass严重依赖固定尺寸。

根据您的要求,我将向您展示一种将数组转换为列表然后返回的解决方案。由于numba中的所有python方法尚不支持反射列表,因此您必须对函数的某些部分使用包装器:

@nb.njit
def delete_lrow(arr_list, num):
    idx_list = []
    for i in range(len(arr_list)):
        if (arr_list[i] != num).all():
            idx_list.append(i)
    res_list = [arr_list[i] for i in idx_list]
    return res_list

def wrap_list_del(arr, num):
    arr_list = list(arr)
    return np.array(delete_lrow(arr_list, num))

arr = np.array([[1,2,3,4],[5,6,7,8],[10,11,5,13],[10,11,3,13],[10,11,99,13]])
arr2 = np.random.randint(0, 256, 100000*4).reshape(-1, 4)

%timeit delete_workaround(arr, 3)
# 1.36 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit wrap_list_del(arr, 3)    
# 69.3 µs ± 4.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit delete_workaround(arr2, 3)
# 1.9 ms ± 68.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit wrap_list_del(arr2, 3)
# 1.05 s ± 103 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

因此,如果您已经拥有数组(即使您还没有数组,但是您的数据是一致的类型),则坚持使用数组对于小型数组的速度大约快50倍,对于大型数组的速度大约快 550倍数组。 这是要记住的东西:Numpy数组用于处理数字数据! Numpy经过严格优化,可用于数字数据!如果数据类型(dtype)是常量并且没有任何超特殊要求(我几乎从未遇到过这种情况),那么将数字数据数组转换为另一种“格式”绝对没有用。 br /> 对于numba优化代码尤其如此! Numba在很大程度上依赖于numpy和常量dtypes / shapes等。如果要使用jitclass,则更多。

答案 1 :(得分:0)

欢迎来到Stacoverflow。您可以简单地使用数组切片来仅选择其中没有3行的行。 下面的代码有点精巧,基本上可以为您覆盖更多详细信息,尽管您可以使用更短的版本并删除不必要的行。密钥分配为rows_final = [x for x in range(a.shape[0]) if x not in rows3]

代码:

import numpy as np

a = np.array([[1,2,3,4],[5,6,7,8],[10,11,3,13]])

ind = np.argwhere(a==3)
rows3 = ind[0]
cols3 = ind[1]

print ("Initial Array: \n", a)
print()
print("rows, cols of a==3 : ", rows3, cols3)

rows_final = [x for x in range(a.shape[0]) if x not in rows3]
a_final = a[rows_final,:]

print()
print ("Final Rows: \n", rows_final)
print ("Final Array: \n", a_final)

输出:

Initial Array: 
 [[ 1  2  3  4]
 [ 5  6  7  8]
 [10 11  3 13]]

rows, cols of a==3 :  [0 2] [2 2]

Final Rows: 
 [1]
Final Array: 
 [[5 6 7 8]]

答案 2 :(得分:-2)

我认为您需要再次将删除分配给变量a,这对我来说很有效。尝试以下代码:

import numpy as np
a = np.array([[1,2,3,4],[5,6,7,8]])
print(a)
i = np.where(a==3)
a=np.delete(a, i, 0) # assign it back to the variable
print(a)