我有numpy
array
'arr'
shape (1756020, 28, 28, 4)
。
基本上'arr'
有1756020
个shape (28,28,4)
小数组。 1756020
数组967210
中的所有数据都是零{&0};并且788810
具有所有非零值。我想删除所有967210
'全部为零'小阵列。我使用条件arr[i]==0.any()
编写了一个if else循环,但这需要花费很多时间。有没有更好的方法呢?
答案 0 :(得分:5)
向量化逻辑的一种方法是将numpy.any
与包含未经测试维度的axis
的元组参数一起使用。
# set up 4d array of ones
A = np.ones((5, 3, 3, 4))
# make second of shape (3, 3, 4) = 0
A[1] = 0 # or A[1, ...] = 0; or A[1, :, :, :] = 0
# find out which are non-zero
res = np.any(A, axis=(1, 2, 3))
print(res)
[True False True True True]
此功能在numpy
v0.17向上提供。根据{{3}}:
轴:无或int或元组,可选
如果这是一个整数元组,则会对多个轴执行缩减, 而不是像以前那样代替单轴或所有轴。
答案 1 :(得分:1)
我制作了一个你提到的大小的测试脚本。使用我的计算机,数组创建(内存错误,如果浮动,这就是为什么布尔)和选择很慢,但找到零似乎相当快:
if __name__ == '__main__':
arr = np.ones((1756020, 28, 28, 4), dtype=bool)
for i in range(0,1756020,2):
arr[i] = 0
print(arr[:5])
s = arr.shape
t0 = time.time()
arr2 = arr.reshape((s[0], np.prod(s[1:])))
ok = np.any(arr2, axis=1)
print(time.time()-t0)
arr_clean = arr2[ok]
print(time.time()-t0)
arr_clean = arr_clean.reshape((np.sum(ok), *s[1:]))
print(time.time()-t0)
print('end')
输出:
0.4846000671386719#零快速查找
29.750200271606445#删除零很慢
29.797000408172607#重塑原始形状[1:]很快