从N维Numpy数组中仅获取非零子数组

时间:2018-05-13 19:56:22

标签: python arrays numpy

我有numpy array 'arr' shape (1756020, 28, 28, 4)。 基本上'arr'1756020shape (28,28,4)小数组。 1756020数组967210中的所有数据都是零{&0};并且788810具有所有非零值。我想删除所有967210'全部为零'小阵列。我使用条件arr[i]==0.any()编写了一个if else循环,但这需要花费很多时间。有没有更好的方法呢?

2 个答案:

答案 0 :(得分:5)

向量化逻辑的一种方法是将numpy.any与包含未经测试维度的axis的元组参数一起使用。

# set up 4d array of ones
A = np.ones((5, 3, 3, 4))

# make second of shape (3, 3, 4) = 0
A[1] = 0  # or A[1, ...] = 0; or A[1, :, :, :] = 0

# find out which are non-zero
res = np.any(A, axis=(1, 2, 3))

print(res)

[True False True True True]

此功能在numpy v0.17向上提供。根据{{​​3}}:

  

无或int或元组,可选

     

如果这是一个整数元组,则会对多个轴执行缩减,   而不是像以前那样代替单轴或所有轴。

答案 1 :(得分:1)

我制作了一个你提到的大小的测试脚本。使用我的计算机,数组创建(内存错误,如果浮动,这就是为什么布尔)和选择很慢,但找到零似乎相当快:

if __name__ == '__main__':
    arr = np.ones((1756020, 28, 28, 4), dtype=bool)
    for i in range(0,1756020,2):
        arr[i] = 0
    print(arr[:5])
    s = arr.shape
    t0 = time.time()
    arr2 = arr.reshape((s[0], np.prod(s[1:])))
    ok = np.any(arr2, axis=1)
    print(time.time()-t0)
    arr_clean = arr2[ok]
    print(time.time()-t0)
    arr_clean = arr_clean.reshape((np.sum(ok), *s[1:]))
    print(time.time()-t0)
    print('end')

输出:

0.4846000671386719#零快速查找

29.750200271606445#删除零很慢

29.797000408172607#重塑原始形状[1:]很快