如果来自A

时间:2015-08-06 19:50:23

标签: python numpy

我有两个这样的数组:

import numpy as np

A = np.array([100, 100, 3, 0, 0, 0, 0, 0, 0, 100, 3, 5], dtype=int)
A = np.reshape(A, (2,2,3))
B = np.array([3, 6, 2, 6, 3, 2, 100, 3, 2, 100, 100, 5])
B = np.reshape(B, (2,2,3))

print(repr(A))
# array([[[100, 100,   3],
#         [  0,   0,   0]],

#        [[  0,   0,   0],
#         [100,   3,   5]]])

print(repr(B))
# array([[[  3,   6,   2],
#         [  6,   3,   2]],

#        [[100,   3,   2],
#         [100, 100,   5]]])

我要做的是从B中选择2x3切片,其中至少有一个值是> 10.如果不满足这个条件,我想要A的相应切片,如下所示:

# desired result 
out = np.array([100, 100, 3, 0, 0, 0, 100, 3, 2, 100, 100, 5])
out = np.reshape(out, (2,2,3))

print(repr(out))
# array([[[100, 100,   3],
#         [  0,   0,   0]],

#        [[100,   3,   2],
#         [100, 100,   5]]])

我可以找到我想要的指数:

filt = ~np.all(B < 10, axis=2)

但我不确定如何提取它们。我已经想出了这个可怕的黑客:

A2 = np.reshape(A, (4,3))
B2 = np.reshape(B, (4,3))
filt2 = np.reshape(filt, 4)

res2 = np.array([[B2[i] if filt2[i] else A2[i] for i in range(0,4)]])
res = np.reshape(res2, (2,2,3))
np.all(res == out)
Out[88]: True

这可能是一个更直接的方式,我怀疑它的NumPy选择,但我还没有想出如何使尺寸合适。思考?

2 个答案:

答案 0 :(得分:1)

import numpy as np
A = np.array([100, 100, 3, 0, 0, 0, 0, 0, 0, 100, 3, 5], dtype=int)
A = np.reshape(A, (2,2,3))
B = np.array([3, 6, 2, 6, 3, 2, 100, 3, 2, 100, 100, 5])
B = np.reshape(B, (2,2,3))

B[B<10] = A[B<10]
# out = B

使用numpy切片,您可以轻松地比较和替换大小匹配的数组之间的值。我希望这就是你想要的。

答案 1 :(得分:1)

您可以使用np.where

print(np.where(np.any(B > 10, axis=2)[..., None], B,  A))

# [[[100 100   3]
#   [  0   0   0]]

#  [[100   3   2]
#   [100 100   5]]]

np.any(B > 10, axis=2)相当于您的filt索引。由于您减少了最后一个轴,它将产生(2, 2)数组,而AB都是(2, 2, 3),因此np.where(np.any(B > 10, axis=2), B, A)会引发索引错误。

幸运的是,np.where支持broadcasting,因此您可以通过使用None建立索引来插入大小为1的新最终轴,np.where将有效地将其视为由(2, 2, 3)索引组成的filt数组重复3次。您可以通过将keepdims=True传递给np.any来保留单身最终维度,从而达到同样的效果:

np.where(np.any(B > 10, axis=2,  keepdims=1), B,  A)