numpy的:ndarray.choose增加数组大小?

时间:2019-01-10 10:06:28

标签: python numpy numpy-ndarray

是否可以对多维数组的查找表使用ndarray.choose(或其他ndarray机制),以便输入数组的每个元素在输出数组中生成多个元素?例如。

input = array([[0, 1],
               [2, 3]])
subs = array([[[ 0,  1],
               [ 2,  3]],
              [[ 4,  5],
               [ 6,  7]],
              [[ 8,  9],
               [10, 11]],
              [[12, 13],
               [14, 15]]])
output = array([[ 0,  1,  4,  5],
                [ 2,  3,  6,  7],
                [ 8,  9, 12, 13],
                [10, 11, 14, 15]])

1 个答案:

答案 0 :(得分:1)

在此示例中,您可以按照以下步骤进行操作:

import numpy as np

input_ = np.array([[0, 1],
                   [2, 3]])
subs = np.array([[[ 0,  1],
                  [ 2,  3]],
                 [[ 4,  5],
                  [ 6,  7]],
                 [[ 8,  9],
                  [10, 11]],
                 [[12, 13],
                  [14, 15]]])
res = subs[input_].transpose((0, 2, 1, 3)).reshape((4, 4))
print(res)
# [[ 0  1  4  5]
#  [ 2  3  6  7]
#  [ 8  9 12 13]
#  [10 11 14 15]]

编辑:

一种更通用的解决方案,支持更多维度以及具有不同维度数量的输入和替换:

import numpy as np

def expand_from(input_, subs):
    input_= np.asarray(input_)
    subs = np.asarray(subs)
    # Take from subs according to input
    res = subs[input_]
    # Input dimensions
    in_dims = input_.ndim
    # One dimension of subs is for indexing
    s_dims = subs.ndim - 1
    # Dimensions that correspond to each other on output
    num_matched = min(in_dims, s_dims)
    matched_dims = [(i, in_dims + i) for i in range(num_matched)]
    # Additional dimensions if there are any
    if in_dims > s_dims:
        extra_dims = list(range(num_matched, in_dims))
    else:
        extra_dims = list(range(2 * num_matched, in_dims + s_dims))
    # Dimensions order permutation
    dims_reorder = [d for m in matched_dims for d in m] + extra_dims
    # Output final shape
    res_shape = ([res.shape[d1] * res.shape[d2] for d1, d2 in matched_dims] +
                 [res.shape[d] for d in extra_dims])
    return res.transpose(dims_reorder).reshape(res_shape)

input_ = np.array([[0, 1],
                   [2, 3]])
subs = np.array([[[ 0,  1],
                  [ 2,  3]],
                 [[ 4,  5],
                  [ 6,  7]],
                 [[ 8,  9],
                  [10, 11]],
                 [[12, 13],
                  [14, 15]]])
output = expand_from(input_, subs)
print(output)
# [[ 0  1  4  5]
#  [ 2  3  6  7]
#  [ 8  9 12 13]
#  [10 11 14 15]]