import numpy as np
# The 3D arrays have the axis: Z, X, Y
arr_keys = np.random.rand(20, 5, 5)
arr_vals = np.random.rand(20, 5, 5)
arr_idx = np.random.rand(5, 5)
对于arr_idx
中的每个网格单元,我想在arr_keys
中查找最接近它的值的Z位置(但具有相同的X,Y位置),并在arr_vals
数组中的对应位置。有没有一种方法可以不使用嵌套的for循环?
因此,如果arr_idx
的X = 0,Y = 0处的值为0.5,我想找到最接近X = 0,Y = 0,Z范围为0到10的数字
arr_keys
中的值,然后我想使用该数字的Z位置(我们称其为Z_prime)来查找arr_vals
中的值(Z_prime,X = 0,Y = 0)
答案 0 :(得分:2)
这是为其创建np.take_along_axis
的问题的类型:
# shape (20, 5, 5)
diff = np.abs(arr_idx - arr_keys)
# argmin(..., keepdims=True) doesn't exist yet - this emulates it
# shape (1, 5, 5)
inds = np.expand_dims(np.argmin(diff, axis=0), axis=0)
# shape (1, 5, 5)
res = np.take_along_axis(arr_vals, inds, axis=0)
# shape (5, 5)
res = res.squeeze(axis=0)
答案 1 :(得分:1)
我认为这可能可行:将轴滚动到正确的方向,找到5x5 X,Y值中每个(绝对)最小值的值的索引,并从{{1}中获取相应的Z值}:
arr_vals
要对此进行测试,请尝试使用idx = np.argmin(np.abs(np.rollaxis(arr_keys,0,3) - arr_idx[:,:,None]), axis=2)
i,j = np.ogrid[:5,:5]
arr_vals[idx[i,j],i,j]
情况:
(3,2,2)
给予:
In [15]: arr_keys
Out[15]:
array([[[ 0.19681533, 0.26897784],
[ 0.60469711, 0.09273087]],
[[ 0.04961604, 0.3460404 ],
[ 0.88406912, 0.41284309]],
[[ 0.46298201, 0.33809574],
[ 0.99604152, 0.4836324 ]]])
In [16]: arr_vals
Out[16]:
array([[[ 0.88865681, 0.88287688],
[ 0.3128103 , 0.24188022]],
[[ 0.23947227, 0.57913325],
[ 0.85768064, 0.91701097]],
[[ 0.78105669, 0.84144339],
[ 0.81071981, 0.69217687]]])
In [17]: arr_idx
Out[17]:
array([[[ 0.31352609],
[ 0.75462329]],
[[ 0.44445286],
[ 0.97086161]]])
答案 2 :(得分:1)
我认为@xnx的答案很好。我的更长,但我还是会发布;)。
另外,请注意:NumPy通过vectorizing操作有效地处理大型多维数组。因此,我建议尽可能避免for
循环。无论您要寻找的任务是什么,通常都有一种避免循环的方法。
arr_keys = np.split(arr_keys, 20)
arr_keys = np.stack(arr_keys, axis=-1)[0]
arr_vals = np.split(arr_vals, 20)
arr_vals = np.stack(arr_vals, axis=-1)[0]
arr_idx = np.expand_dims(arr_idx, axis=-1)
difference = np.abs(arr_keys - arr_idx)
minimum = np.argmin(difference, axis=-1)
result = np.take_along_axis(arr_vals, np.expand_dims(minimum, axis=-1), axis=-1)
result = np.squeeze(result, axis=-1)
答案 3 :(得分:0)
比已经发布的解决方案稍微冗长,但更易于理解。
import numpy as np
# The 3D arrays have the axis: Z, X, Y
arr_keys = np.random.rand(20, 5, 5)
arr_vals = np.random.rand(20, 5, 5)
arr_idx = np.random.rand(5, 5)
arr_idx = arr_idx[np.newaxis, :, :]
dist = np.abs(arr_idx - arr_keys)
dist_ind = np.argmin(dist, axis=0)
x = np.arange(0, 5, 1)
y = np.arange(0, 5, 1)
xx, yy = np.meshgrid(x, y)
res = arr_vals[dist_ind, yy, xx]