我有一个这样的数组:tmp.shape = (128, 64, 64)
我正在计算128
轴上的所有零,如下所示:
nonzeros = np.count_nonzero(tmp, axis=0) // shape = (64, 64)
我有一个数组c.shape = (64, 64)
现在我想沿着128轴将c的值添加到tmp,但前提是tmp的值是> 0:
for i in range(tmp.shape[0]):
axis1 = tmp[i,:,:]
tmp[i, :, :] += (c / nonzeros) // only if tmp[i, :, :] > 0
这可以在短时间内完成吗?或者我必须使用多个循环? 我希望任何人都可以建议没有其他循环的解决方案
这样的事情不起作用:
tmp[i, tmp > 0.0, tmp > 0.0] += (c / nonzeros)
IndexError:数组索引太多
LONG VERSION
for i in range(tmp.shape[0]):
for x in range(tmp.shape[1]):
for y in range(tmp.shape[2]):
pixel = tmp[i, x, y]
if pixel != 0:
pixel += (c[x,y] / nonzeros[x,y])
答案 0 :(得分:0)
您可以使用np.where
和广播。修复你的示例代码后(添加到像素不会修改tmp),
def fast(tmp, c, nonzeros):
return tmp + np.where(tmp > 0, c/nonzeros, 0)
给了我
In [6]: tmp = np.random.randint(0, 5, (128, 64, 64)).astype(float)
...: c = np.random.randint(10, 15, (64, 64)).astype(float)
...: nonzeros = np.count_nonzero(tmp, axis=0)
...:
In [7]: %time slow_result = slow(tmp, c, nonzeros)
CPU times: user 488 ms, sys: 16 ms, total: 504 ms
Wall time: 553 ms
In [8]: %time fast_result = fast(tmp, c, nonzeros)
CPU times: user 4 ms, sys: 4 ms, total: 8 ms
Wall time: 16.4 ms
In [9]: np.allclose(slow_result, fast_result)
Out[9]: True
或者,您通常可以使用乘法替换np.where
,例如tmp + (tmp > 0) * (c/nonzeros)
。
修改代码以防止非零元素为零的情况留给读者练习。 ; - )
答案 1 :(得分:0)
你基本上是在广播的c/nonzeros
中加入tmp数组,除了在tmp元素为零的地方。因此,一种方法是预先存储0s
的掩码,添加c/nonzeros
,最后使用掩码重置tmp
元素。
因此,实施将是 -
mask = tmp==0
tmp+= c/nonzeros
tmp[mask] = 0
运行时测试
方法 -
# @DSM's soln
def fast(tmp, c, nonzeros):
return tmp + np.where(tmp > 0, c/nonzeros, 0)
# Proposed in this post
def fast2(tmp, c, nonzeros):
mask = tmp==0
tmp+= c/nonzeros
tmp[mask] = 0
return tmp
计时 -
In [341]: # Setup inputs
...: M,N = 128,64
...: tmp = np.random.randint(0,10,(M,N,N)).astype(float)
...: c = np.random.rand(N,N)*100
...: nonzeros = np.count_nonzero(tmp, axis=0)
...:
...: # Make copies for testing as input would be edited with the approaches
...: tmp1 = tmp.copy()
...: tmp2 = tmp.copy()
...:
In [342]: %timeit fast(tmp1, c, nonzeros)
100 loops, best of 3: 2.22 ms per loop
In [343]: %timeit fast2(tmp2, c, nonzeros)
1000 loops, best of 3: 1.61 ms per loop
更短的替代
如果你正在寻找一个紧凑的代码,这是另一个使用non-0s
的掩码与c/nonzeros
进行广播的逐元素乘法并添加到tmp
,因此有一个 - 衬里解决方案,如此 -
tmp += (tmp!=0)*(c/nonzeros)
注意:为避免按0
进行划分,我们可以使用nonzeros
以外的任何内容编辑0s
0
,1
1}}然后使用发布的方法,如此 -
nonzeros = np.where(nonzeros > 0, nonzeros, 1)