在numpy中插入三维矩阵中的非对齐元素

时间:2014-10-08 02:34:56

标签: python performance numpy matrix multidimensional-array

我使用 numpy 1.9 python 2.7.5 处理三维矩阵。 这是一个例子:

>>> A
array([[[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]],

       [[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]]])

>>> B
array([[[-1., -1., -1.],
        [99., 100., 101.],
        [-1., -1., -1.],
        [-1., -1., -1.],
        [-1., -1., -1.]],

       [[-1., -1., -1.],
        [-1., -1., -1.],
        [102., 103., 104.],
        [-1., -1., -1.],
        [-1., -1., -1.]]])

>>> C
array([1, 2])

根据C,我想在A中插入所有元素。 示例:c[0] = 1 => After A[0, 1, :] has to be inserted B[0, 1, :]

以下是预期结果的示例

>>> D
array([[[1.,  1.,  1.],
        [1.,  1.,  1.],
        [99., 100., 101.],
        [1.,  1.,  1.],
        [1.,  1.,  1.],
        [1.,  1.,  1.]],

       [[1.,  1.,  1.],
        [1.,  1.,  1.],
        [1.,  1.,  1.],
        [102., 103., 104.]
        [1.,  1.,  1.],
        [1.,  1.,  1.]]])

我发现this stackoverflow question与我的相似,只是解决方案仅适用于二维矩阵,而且我使用的是三维矩阵。

这是我的解决方案,但结果不正确:

C2 = np.repeat(C, 3)
r1 = np.repeat(np.arange(A.shape[0]), 3)
r2 = np.tile(np.arange(3), A.shape[0])
index_map = np.ravel_multi_index((r1, C2, r2), A.shape) + 1
np.insert(A.ravel(), index_map, B.ravel()[index_map]).reshape(A.shape[0], A.shape[1] + 1, A.shape[2])

这是一个使用for循环的正确但缓慢的解决方案:

A_2 = np.zeros((A.shape[0], A.shape[1] + 1, A.shape[2]))
for j in xrange(np.size(C, 0)):
  i = C[j]
  A_2[j, :, :] = np.concatenate((A[j, 0:i + 1, :], [B[j, i, :]], A[j, i + 1:, :]))

有什么想法吗?

谢谢!

2 个答案:

答案 0 :(得分:2)

这似乎可以替代你的(非工作)矢量化解决方案的最后一行:

linear = np.insert(A.ravel(), index_map + r2[::-1], B.ravel()[index_map - 1])
linear.reshape(A.shape[0], A.shape[1] + 1, A.shape[2])

这就像你的矢量化解决方案,但有一些调整,以使索引正确。第一个关键是要意识到我需要“撤消”你对index_map的添加1。下一个顿悟是当你插入linear时,你需要偏移每一行中的索引,因为当你插入元素时,后续的元素会被移回。因此,虽然index_map[4,5,6,22,23,24],但实际上我们需要[6,6,6,24,24,24],而我只是为了这个目的而重复使用/滥用r2[::-1]

B.ravel()[index_map - 1]似乎也可以简化为B[r1,C2,r2]。为了消除r2[::-1]的奇怪减法,稍微简化一下就可以了:

C2 = np.repeat(C, 3)
r1 = np.repeat(np.arange(A.shape[0]), 3)
r2 = np.repeat(2, A.shape[0] * A.shape[2])
index_map = np.ravel_multi_index((r1, C2, r2), A.shape) + 1
linear = np.insert(A.ravel(), index_map, B[r1,C2,r2])
linear.reshape(A.shape[0], A.shape[1] + 1, A.shape[2])

答案 1 :(得分:2)

您的代码存在的问题是,当您需要插入多个代码时 元素顺序,您需要将它们插入相同的位置。 比较:

In [139]: x = np.ones(5); x
Out[139]: array([ 1.,  1.,  1.,  1.,  1.])

In [140]: np.insert(x, [1,2,3], 100)
Out[140]: array([   1.,  100.,    1.,  100.,    1.,  100.,    1.,    1.])

In [141]: np.insert(x, [1,1,1], 100)
Out[141]: array([   1.,  100.,  100.,  100.,    1.,    1.,    1.,    1.])

编辑:原始答案包括完整的解散/重塑 回来了,但在3D中你需要很多照顾才能做到这一点。有一个 更简单的解决方案,考虑到np.insertnp.take的事实 mi+1接受“axis”参数并允许多值插入。 这仍然需要一些重塑,但它没有诉诸 np.choose。另外,请注意要插入的np.insert In [50]: mi = np.ravel_multi_index([np.arange(A.shape[0]), C], A.shape[:2]); mi Out[50]: array([1, 7]) In [51]: bvals = np.take(B.reshape(-1, B.shape[-1]), mi, axis=0); bvals Out[51]: array([[ 99., 100., 101.], [ 102., 103., 104.]]) In [52]: result = (np.insert(A.reshape(-1, A.shape[2]), mi + 1, bvals, axis=0) .reshape(A.shape[0], -1, A.shape[2])); result Out[52]: array([[[ 1., 1., 1.], [ 1., 1., 1.], [ 99., 100., 101.], [ 1., 1., 1.], [ 1., 1., 1.], [ 1., 1., 1.]], [[ 1., 1., 1.], [ 1., 1., 1.], [ 1., 1., 1.], [ 102., 103., 104.], [ 1., 1., 1.], [ 1., 1., 1.]]]) 参数 之后,而不是在所选行之前:

In [18]: ixs = np.repeat(np.array([np.arange(A.shape[0]),
                                    C+1,
                                    np.zeros(A.shape[0], dtype=np.int_)]),
                          A.shape[2], axis=1); ixs
   ....: 
Out[18]: 
array([[0, 0, 0, 1, 1, 1],
       [2, 2, 2, 3, 3, 3],
       [0, 0, 0, 0, 0, 0]])

In [19]: mi = np.ravel_multi_index(ixs, A.shape); mi
Out[19]: array([ 6,  6,  6, 24, 24, 24])

In [20]: result = (np.insert(A.ravel(), mi, bvals)
                    .reshape(A.shape[0], A.shape[1] +1, A.shape[2])); result
   ....: 
Out[20]: 
array([[[   1.,    1.,    1.],
        [   1.,    1.,    1.],
        [  99.,  100.,  101.],
        [   1.,    1.,    1.],
        [   1.,    1.,    1.],
        [   1.,    1.,    1.]],

       [[   1.,    1.,    1.],
        [   1.,    1.,    1.],
        [   1.,    1.,    1.],
        [ 102.,  103.,  104.],
        [   1.,    1.,    1.],
        [   1.,    1.,    1.]]])

In [21]: result = (np.insert(A.ravel(), mi, bvals)
                    .reshape(A.shape[0], A.shape[1] +1, A.shape[2])); result
   ....: 
Out[21]: 
array([[[   1.,    1.,    1.],
        [   1.,    1.,    1.],
        [  99.,  100.,  101.],
        [   1.,    1.,    1.],
        [   1.,    1.,    1.],
        [   1.,    1.,    1.]],

       [[   1.,    1.,    1.],
        [   1.,    1.,    1.],
        [   1.,    1.,    1.],
        [ 102.,  103.,  104.],
        [   1.,    1.,    1.],
        [   1.,    1.,    1.]]])

这是最初的答案:

{{1}}