“反转”数组,即将2d索引列表转换为1d索引的2d数组

时间:2018-12-22 01:54:04

标签: python arrays numpy indices

我想以一种最好的numpythonic方式解决的问题是: 我有一个二维索引列表A,例如:

A = [(0, 3), (2, 2), (3, 1)]

我的目标是现在获得一个数组

[[H H H 0],
 [H H H H],
 [H H 1 H],
 [H 2 H H]]

其中H是一些默认值(例如-1) 因此,问题通常在于以这种方式反转数组。

如果A是内射词(没有值出现两次),我可以严格地声明它:

让A是一个二维索引的内射数组。    然后,生成一个二维数组B,使得B [i,j] = A.index((i,j))

或者对于A不一定是单射的:

让A是一个二维索引的内射数组。    然后,生成一个二维数组B,使得A [B [i,j]] =(i,j)

更具体地说,在非内射的情况下,我们可以通过附加的“决策”功能来解决这种情况。 说

A = [(0, 3), (2, 2), (3, 1), (0, 3)]

然后要解决位置0和3上的(0,3)之间的冲突,我想对等效索引应用一些函数以找到一个确定的值。

例如: 就我而言,具体来说,我有第二个数组C,其长度与A相同。 如果最后2d数组中一个“位置”的A中有多个候选项(2d索引),则选择的那个应该是A中1d索引使C中的值最小化的一个。

我希望这些例子能使问题解决。 谢谢您的帮助。

编辑:更多示例:

    A = [(0, 3), (2, 2), (3, 1)]
    print(my_dream_func(A, default=7)
    >>> [[7 7 7 0],
         [7 7 7 7],
         [7 7 1 7],
         [7 2 7 7]]

    A = [(0, 3), (2, 2), (3, 1), (0, 3)]
    print(my_dream_func(A, default=7))
    >>> Err: an index appears twice

这种情况的替代方案:

    def resolveFunc(indices):
        c = [0.5, 2.0, 3.4, -1.9]
        return(np.argmin(c[indices]))

    A = [(0, 3), (2, 2), (3, 1), (0, 3)]

    print(my_dream_func(A, resolveFunc, default=7))
    #now resolveFunc is executed on 0 and 3
    #because 0.5 > -1.9, 3 is chosen as the value for (0, 3)
    >>> [[7 7 7 3],
         [7 7 7 7],
         [7 7 1 7],
         [7 2 7 7]]

2 个答案:

答案 0 :(得分:1)

我将按照以下步骤进行操作:

In [11]: A = np.array([(0, 3), (2, 2), (3, 1)])

In [12]: a = np.full((len(A), len(A)), 7)  # here H = 7

In [13]: a
Out[13]:
array([[7, 7, 7, 7],
       [7, 7, 7, 7],
       [7, 7, 7, 7],
       [7, 7, 7, 7]])

In [14]: a[A[:, 0], A[:, 1]] = np.arange(len(A))

In [15]: a
Out[15]:
array([[7, 7, 7, 0],
       [7, 7, 7, 7],
       [7, 7, 1, 7],
       [7, 2, 7, 7]])

“决定者”功能是最后的胜利。

如果要选择其他决策函数,则可以先指定/修改元组列表(和枚举),而不要尝试在numpy中做一些聪明的事情...

答案 1 :(得分:1)

Numpy支持将多个值同时分配给多个indizes。 因此,使用这种最灵活的方式编写函数的方式将是:

import numpy as np

def f(idx, shape, default):
    arr = np.full(shape, default)
    arr[idx] = np.arange(0, len(idx))
    return arr

shape=(4,4)
default=7
idx=[(1,2),(0,3)]

print(f(idx, shape, default))

如果idx中有重复项,最后一个索引元组会覆盖所有前任。