是否可以对多维数组的查找表使用ndarray.choose(或其他ndarray机制),以便输入数组的每个元素在输出数组中生成多个元素?例如。
input = array([[0, 1],
[2, 3]])
subs = array([[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15]]])
output = array([[ 0, 1, 4, 5],
[ 2, 3, 6, 7],
[ 8, 9, 12, 13],
[10, 11, 14, 15]])
答案 0 :(得分:1)
在此示例中,您可以按照以下步骤进行操作:
import numpy as np
input_ = np.array([[0, 1],
[2, 3]])
subs = np.array([[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15]]])
res = subs[input_].transpose((0, 2, 1, 3)).reshape((4, 4))
print(res)
# [[ 0 1 4 5]
# [ 2 3 6 7]
# [ 8 9 12 13]
# [10 11 14 15]]
编辑:
一种更通用的解决方案,支持更多维度以及具有不同维度数量的输入和替换:
import numpy as np
def expand_from(input_, subs):
input_= np.asarray(input_)
subs = np.asarray(subs)
# Take from subs according to input
res = subs[input_]
# Input dimensions
in_dims = input_.ndim
# One dimension of subs is for indexing
s_dims = subs.ndim - 1
# Dimensions that correspond to each other on output
num_matched = min(in_dims, s_dims)
matched_dims = [(i, in_dims + i) for i in range(num_matched)]
# Additional dimensions if there are any
if in_dims > s_dims:
extra_dims = list(range(num_matched, in_dims))
else:
extra_dims = list(range(2 * num_matched, in_dims + s_dims))
# Dimensions order permutation
dims_reorder = [d for m in matched_dims for d in m] + extra_dims
# Output final shape
res_shape = ([res.shape[d1] * res.shape[d2] for d1, d2 in matched_dims] +
[res.shape[d] for d in extra_dims])
return res.transpose(dims_reorder).reshape(res_shape)
input_ = np.array([[0, 1],
[2, 3]])
subs = np.array([[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15]]])
output = expand_from(input_, subs)
print(output)
# [[ 0 1 4 5]
# [ 2 3 6 7]
# [ 8 9 12 13]
# [10 11 14 15]]